blob: 5c78d984a9c894a150d5da3210c4b0d488738797 [file] [log] [blame]
//===- MIR2Vec.cpp - Implementation of MIR2Vec ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See the LICENSE file for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the MIR2Vec algorithm for Machine IR embeddings.
///
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/MIR2Vec.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Regex.h"
using namespace llvm;
using namespace mir2vec;
// Debug type used by LLVM_DEBUG and STATISTIC in this file.
#define DEBUG_TYPE "mir2vec"
// Incremented once per canonical base opcode that has no entry in the loaded
// vocabulary (see handleMissingEntity in MIRVocabulary::generateStorage).
STATISTIC(MIRVocabMissCounter,
"Number of lookups to MIR entities not present in the vocabulary");
// Command-line category grouping all MIR2Vec options.
cl::OptionCategory llvm::mir2vec::MIR2VecCategory("MIR2Vec Options");
// Path of the JSON vocabulary read by MIR2VecVocabLegacyAnalysis::readVocabulary.
// When left empty (the default), readVocabulary returns an invalid_argument
// error instead of loading anything.
// FIXME: Use a default vocab when not specified
static cl::opt<std::string>
VocabFile("mir2vec-vocab-path", cl::Optional,
cl::desc("Path to the vocabulary file for MIR2Vec"), cl::init(""),
cl::cat(MIR2VecCategory));
// Scale factor applied to every opcode embedding when the vocabulary storage
// is generated (MIRVocabulary::generateStorage). Defaults to 1.0 (no scaling).
cl::opt<float>
llvm::mir2vec::OpcWeight("mir2vec-opc-weight", cl::Optional, cl::init(1.0),
cl::desc("Weight for machine opcode embeddings"),
cl::cat(MIR2VecCategory));
//===----------------------------------------------------------------------===//
// Vocabulary Implementation
//===----------------------------------------------------------------------===//
/// Build a vocabulary for the target described by \p TII.
///
/// The canonical base-opcode set is derived from the target first; its size
/// fixes where the opcode section ends (Layout.OperandBase). The embeddings in
/// \p OpcodeEntries are then laid out into Storage, and the total entry count
/// is taken from the resulting storage.
MIRVocabulary::MIRVocabulary(VocabMap &&OpcodeEntries,
                             const TargetInstrInfo &TII)
    : TII(TII) {
  buildCanonicalOpcodeMapping();

  const unsigned NumCanonicalOpcodes = UniqueBaseOpcodeNames.size();
  assert(NumCanonicalOpcodes > 0 &&
         "No canonical opcodes found for target - invalid vocabulary");

  // The opcode section occupies indices [0, NumCanonicalOpcodes).
  Layout.OperandBase = NumCanonicalOpcodes;
  generateStorage(OpcodeEntries);
  Layout.TotalEntries = Storage.size();
}
/// Factory for MIRVocabulary.
///
/// \returns the constructed vocabulary, or an invalid_argument error when
/// \p Entries is empty. The check matters because generateStorage reads the
/// embedding dimension from the first entry of the map.
Expected<MIRVocabulary> MIRVocabulary::create(VocabMap &&Entries,
                                              const TargetInstrInfo &TII) {
  if (!Entries.empty())
    return MIRVocabulary(std::move(Entries), TII);

  return createStringError(errc::invalid_argument,
                           "Empty vocabulary entries provided");
}
/// Map a target instruction name to its base opcode name.
///
/// The base name is the leftmost maximal run of ASCII letters and underscores
/// in \p InstrName, with trailing underscores removed. For example,
/// "ADD32rr" -> "ADD" and "ARITH_FENCE" -> "ARITH_FENCE". If the name contains
/// no letter or underscore at all, the whole name is returned unchanged.
///
/// This used to be implemented with the unanchored regex "([a-zA-Z_]+)". The
/// hand-written scan below produces the same leftmost-longest match. It avoids
/// running the regex engine, which matters because this function is called
/// once per target opcode when building the canonical mapping and again for
/// every lookup through getCanonicalOpcodeIndex.
///
/// TODO: Consider more sophisticated extraction:
/// - Handle complex prefixes like "AVX1_SETALLONES" correctly (currently it
///   naively maps to "AVX").
/// - Extract width suffixes (8, 16, 32, 64) as separate features.
/// - Capture addressing-mode suffixes (r, i, m, ri, etc.) for better analysis
///   (currently "MOV32mi" maps to "MOV", but "ADDPDrr" maps to "ADDPDrr").
std::string MIRVocabulary::extractBaseOpcodeName(StringRef InstrName) {
  assert(!InstrName.empty() && "Instruction name should not be empty");

  // ASCII letters and underscore, i.e. the character class [a-zA-Z_].
  auto IsBaseNameChar = [](char C) {
    return (C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') || C == '_';
  };

  // Locate the start of the first run of base-name characters.
  const size_t Start = InstrName.find_if(IsBaseNameChar);

  // Fallback to the original name if no such character exists.
  if (Start == StringRef::npos)
    return InstrName.str();

  // Take the maximal run beginning at Start and drop trailing underscores.
  StringRef Match =
      InstrName.drop_front(Start).take_while(IsBaseNameChar).rtrim('_');
  return Match.str();
}
/// Return the canonical index of \p BaseName.
///
/// The index is the name's position in the iteration order of
/// UniqueBaseOpcodeNames. The mapping must already be built, and
/// \p BaseName must be one of its entries; both conditions are asserted.
unsigned MIRVocabulary::getCanonicalIndexForBaseName(StringRef BaseName) const {
  assert(!UniqueBaseOpcodeNames.empty() && "Canonical mapping not built");

  // Use the container's own lookup instead of a linear std::find. That
  // replaces O(n) string comparisons with the container's native search.
  auto It = UniqueBaseOpcodeNames.find(BaseName.str());
  assert(It != UniqueBaseOpcodeNames.end() &&
         "Base name not found in unique opcodes");

  // Position from begin() defines the canonical index. For node-based
  // containers this walk is still linear, but it is only pointer-chasing.
  return std::distance(UniqueBaseOpcodeNames.begin(), It);
}
/// Return the canonical index for a target machine opcode.
///
/// The opcode's name is looked up in TII and reduced to its base opcode name
/// before being resolved to an index.
unsigned MIRVocabulary::getCanonicalOpcodeIndex(unsigned Opcode) const {
  return getCanonicalIndexForBaseName(
      extractBaseOpcodeName(TII.getName(Opcode)));
}
/// Return the textual key stored at vocabulary position \p Pos.
///
/// Only the opcode section exists today, so every valid position maps to a
/// canonical base opcode name. Any other position is a programming error.
std::string MIRVocabulary::getStringKey(unsigned Pos) const {
  assert(Pos < Layout.TotalEntries && "Position out of bounds in vocabulary");

  const bool InOpcodeSection =
      Pos < Layout.OperandBase && Pos < UniqueBaseOpcodeNames.size();
  if (!InOpcodeSection)
    llvm_unreachable("Invalid position in vocabulary");

  // The canonical index is the position in UniqueBaseOpcodeNames' order.
  return *std::next(UniqueBaseOpcodeNames.begin(), Pos);
}
void MIRVocabulary::generateStorage(const VocabMap &OpcodeMap) {
// Helper for handling missing entities in the vocabulary.
// Currently, we use a zero vector. In the future, we will throw an error to
// ensure that *all* known entities are present in the vocabulary.
auto handleMissingEntity = [](StringRef Key) {
LLVM_DEBUG(errs() << "MIR2Vec: Missing vocabulary entry for " << Key
<< "; using zero vector. This will result in an error "
"in the future.\n");
++MIRVocabMissCounter;
};
// Initialize opcode embeddings section
unsigned EmbeddingDim = OpcodeMap.begin()->second.size();
std::vector<Embedding> OpcodeEmbeddings(Layout.OperandBase,
Embedding(EmbeddingDim));
// Populate opcode embeddings using canonical mapping
for (auto COpcodeName : UniqueBaseOpcodeNames) {
if (auto It = OpcodeMap.find(COpcodeName); It != OpcodeMap.end()) {
auto COpcodeIndex = getCanonicalIndexForBaseName(COpcodeName);
assert(COpcodeIndex < Layout.OperandBase &&
"Canonical index out of bounds");
OpcodeEmbeddings[COpcodeIndex] = It->second;
} else {
handleMissingEntity(COpcodeName);
}
}
// TODO: Add operand/argument embeddings as additional sections
// This will require extending the vocabulary format and layout
// Scale the vocabulary sections based on the provided weights
auto scaleVocabSection = [](std::vector<Embedding> &Embeddings,
double Weight) {
for (auto &Embedding : Embeddings)
Embedding *= Weight;
};
scaleVocabSection(OpcodeEmbeddings, OpcWeight);
std::vector<std::vector<Embedding>> Sections(1);
Sections[0] = std::move(OpcodeEmbeddings);
Storage = ir2vec::VocabStorage(std::move(Sections));
}
void MIRVocabulary::buildCanonicalOpcodeMapping() {
// Check if already built
if (!UniqueBaseOpcodeNames.empty())
return;
// Build mapping from opcodes to canonical base opcode indices
for (unsigned Opcode = 0; Opcode < TII.getNumOpcodes(); ++Opcode) {
std::string BaseOpcode = extractBaseOpcodeName(TII.getName(Opcode));
UniqueBaseOpcodeNames.insert(BaseOpcode);
}
LLVM_DEBUG(dbgs() << "MIR2Vec: Built canonical mapping for target with "
<< UniqueBaseOpcodeNames.size()
<< " unique base opcodes\n");
}
//===----------------------------------------------------------------------===//
// MIR2VecVocabLegacyAnalysis Implementation
//===----------------------------------------------------------------------===//
// Legacy pass identity; its address (not its value) identifies the pass.
char MIR2VecVocabLegacyAnalysis::ID = 0;
// Register the vocabulary analysis with the legacy pass registry under the
// command-line name "mir2vec-vocab-analysis". It depends on
// MachineModuleInfoWrapperPass, which getMIR2VecVocabulary queries for machine
// functions. The trailing (false, true) flags follow the INITIALIZE_PASS macro
// parameter order (CFG-only, is-analysis).
INITIALIZE_PASS_BEGIN(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis",
"MIR2Vec Vocabulary Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
INITIALIZE_PASS_END(MIR2VecVocabLegacyAnalysis, "mir2vec-vocab-analysis",
"MIR2Vec Vocabulary Analysis", false, true)
/// Human-readable name reported for this pass.
StringRef MIR2VecVocabLegacyAnalysis::getPassName() const {
  constexpr const char *PassName = "MIR2Vec Vocabulary Analysis";
  return PassName;
}
/// Load the JSON vocabulary named by --mir2vec-vocab-path into StrVocabMap.
///
/// \returns an error when no path was given, when the file cannot be read,
/// when it is not valid JSON, or when its "entities" section fails to parse.
/// On error StrVocabMap is left untouched.
Error MIR2VecVocabLegacyAnalysis::readVocabulary() {
  // TODO: Extend vocabulary format to support multiple sections
  // (opcodes, operands, etc.) similar to IR2Vec structure
  if (VocabFile.empty())
    return createStringError(
        errc::invalid_argument,
        "MIR2Vec vocabulary file path not specified; set it "
        "using --mir2vec-vocab-path");

  auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
  if (!BufOrError)
    return createFileError(VocabFile, BufOrError.getError());

  // BufOrError owns the buffer for the rest of this function, so the
  // StringRef passed to json::parse stays valid while parsing.
  Expected<json::Value> ParsedVocabValue =
      json::parse(BufOrError.get()->getBuffer());
  if (!ParsedVocabValue)
    return ParsedVocabValue.takeError();

  // Parse into a scratch map and publish it only on success.
  // getMIR2VecVocabulary treats a non-empty StrVocabMap as "already loaded".
  // If parseVocabSection failed after inserting some entries, parsing in place
  // would leave a partial map that a later call silently reuses.
  decltype(StrVocabMap) ParsedEntries;
  unsigned Dim = 0; // Out-parameter of parseVocabSection; not needed here.
  if (auto Err = ir2vec::VocabStorage::parseVocabSection(
          "entities", *ParsedVocabValue, ParsedEntries, Dim))
    return Err;

  StrVocabMap = std::move(ParsedEntries);
  return Error::success();
}
/// Build a MIRVocabulary for the target of module \p M.
///
/// The vocabulary file is read on first use, i.e. while StrVocabMap is still
/// empty. TargetInstrInfo is taken from the subtarget of the first defined
/// function in \p M that has a MachineFunction in MachineModuleInfo.
///
/// \returns an error if reading the vocabulary fails, if no such machine
/// function exists, or if MIRVocabulary::create rejects the entries.
Expected<mir2vec::MIRVocabulary>
MIR2VecVocabLegacyAnalysis::getMIR2VecVocabulary(const Module &M) {
  if (StrVocabMap.empty()) {
    if (Error ReadErr = readVocabulary())
      return std::move(ReadErr);
  }

  MachineModuleInfo &MMI = getAnalysis<MachineModuleInfoWrapperPass>().getMMI();

  for (const auto &Fn : M) {
    if (Fn.isDeclaration())
      continue;

    auto *MF = MMI.getMachineFunction(Fn);
    if (!MF)
      continue;

    const TargetInstrInfo &TargetII = *MF->getSubtarget().getInstrInfo();
    return mir2vec::MIRVocabulary::create(std::move(StrVocabMap), TargetII);
  }

  // Nothing in the module carries machine code to take target info from.
  return createStringError(errc::invalid_argument,
                           "No machine functions found in module");
}
//===----------------------------------------------------------------------===//
// Printer Passes Implementation
//===----------------------------------------------------------------------===//
// Legacy pass identity; its address (not its value) identifies the pass.
char MIR2VecVocabPrinterLegacyPass::ID = 0;
// Register the vocabulary printer under the command-line name
// "print-mir2vec-vocab". It requires MIR2VecVocabLegacyAnalysis (queried in
// doFinalization) and MachineModuleInfoWrapperPass. The trailing
// (false, true) flags follow the INITIALIZE_PASS macro parameter order
// (CFG-only, is-analysis).
INITIALIZE_PASS_BEGIN(MIR2VecVocabPrinterLegacyPass, "print-mir2vec-vocab",
"MIR2Vec Vocabulary Printer Pass", false, true)
INITIALIZE_PASS_DEPENDENCY(MIR2VecVocabLegacyAnalysis)
INITIALIZE_PASS_DEPENDENCY(MachineModuleInfoWrapperPass)
INITIALIZE_PASS_END(MIR2VecVocabPrinterLegacyPass, "print-mir2vec-vocab",
"MIR2Vec Vocabulary Printer Pass", false, true)
/// No per-function work: the vocabulary is printed once in doFinalization.
/// \returns false, since the machine function is never modified.
bool MIR2VecVocabPrinterLegacyPass::runOnMachineFunction(MachineFunction &MF) {
  (void)MF;
  constexpr bool Modified = false;
  return Modified;
}
/// Print every vocabulary entry of module \p M to OS as "Key: <name>: "
/// followed by the entry's own print output. If the vocabulary cannot be
/// obtained, an error line is printed instead.
/// \returns false, since the module is never modified.
bool MIR2VecVocabPrinterLegacyPass::doFinalization(Module &M) {
  auto VocabOrErr =
      getAnalysis<MIR2VecVocabLegacyAnalysis>().getMIR2VecVocabulary(M);
  if (!VocabOrErr) {
    OS << "MIR2Vec Vocabulary Printer: Failed to get vocabulary - "
       << toString(VocabOrErr.takeError()) << "\n";
    return false;
  }

  auto &Vocab = *VocabOrErr;
  unsigned Index = 0;
  for (const auto &Emb : Vocab) {
    OS << "Key: " << Vocab.getStringKey(Index) << ": ";
    Emb.print(OS);
    ++Index;
  }
  return false;
}
/// Create a vocabulary printer pass that writes to \p OS.
/// Ownership of the returned pass transfers to the caller (the pass manager).
MachineFunctionPass *
llvm::createMIR2VecVocabPrinterLegacyPass(raw_ostream &OS) {
  auto *Printer = new MIR2VecVocabPrinterLegacyPass(OS);
  return Printer;
}