blob: d42749c553eea4246109b235ffdf58cd66db6c03 [file] [edit]
//===- MIR2VecTest.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/MIR2Vec.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Triple.h"
#include "gtest/gtest.h"
using namespace llvm;
using namespace mir2vec;
// Name -> embedding map used to seed one MIRVocabulary section in the tests
// below (opcodes, common operands, physical registers, virtual registers).
using VocabMap = std::map<std::string, ir2vec::Embedding>;
namespace {
TEST(MIR2VecTest, RegexExtraction) {
  // Each entry is (target opcode name, expected base opcode name).
  static const std::pair<const char *, const char *> Cases[] = {
      // Names without any suffix come back unchanged.
      {"NOP", "NOP"},
      {"RET", "RET"},
      // Width digits and operand-kind letters are stripped.
      {"ADD16ri", "ADD"},
      {"ADD32rr", "ADD"},
      {"ADD64rm", "ADD"},
      {"MOV8ri", "MOV"},
      {"MOV32mr", "MOV"},
      {"PUSH64r", "PUSH"},
      {"POP64r", "POP"},
      {"CALL64pcrel32", "CALL"},
      // Underscore-separated numeric suffixes are stripped as well.
      {"JMP_4", "JMP"},
      {"SOME_INSTR_123", "SOME_INSTR"},
      // Leading digits are dropped.
      {"123ADD", "ADD"},
  };
  for (const auto &[Input, Expected] : Cases)
    EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName(Input), Expected)
        << "input: " << Input;

  // A purely numeric name must still produce a non-empty base name.
  EXPECT_FALSE(MIRVocabulary::extractBaseOpcodeName("123").empty());
}
/// Base fixture for the MIR2Vec vocabulary and embedding tests.
///
/// SetUp() builds an x86_64-unknown-linux-gnu TargetMachine, a module holding
/// a single empty `void test()` function, and the MachineFunction for it. Tests
/// can then create MIRVocabulary instances against real target instruction
/// and register info. If the target, the TargetMachine, or the instr/register
/// info is unavailable, the test is skipped.
class MIR2VecVocabTestFixture : public ::testing::Test {
protected:
  std::unique_ptr<LLVMContext> Ctx;
  std::unique_ptr<Module> M;
  std::unique_ptr<TargetMachine> TM;
  const TargetInstrInfo *TII = nullptr;
  const TargetRegisterInfo *TRI = nullptr;
  // Owns the MachineFunction that MF points into. Declared after Ctx/M/TM so
  // that it is destroyed before them (members die in reverse order).
  std::unique_ptr<MachineModuleInfo> MMI;
  MachineFunction *MF = nullptr;

  // Registers all compiled-in targets once per test suite. SetUpTestSuite is
  // the current GoogleTest name for the deprecated SetUpTestCase hook.
  static void SetUpTestSuite() {
    InitializeAllTargets();
    InitializeAllTargetMCs();
  }

  void SetUp() override {
    Triple TargetTriple("x86_64-unknown-linux-gnu");
    std::string Error;
    const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
    // GTEST_SKIP() returns from SetUp() by itself; no explicit return needed.
    if (!T)
      GTEST_SKIP() << "x86_64-unknown-linux-gnu target triple not available; "
                      "Skipping test";

    Ctx = std::make_unique<LLVMContext>();
    M = std::make_unique<Module>("test", *Ctx);
    M->setTargetTriple(TargetTriple);

    TargetOptions Options;
    TM = std::unique_ptr<TargetMachine>(
        T->createTargetMachine(TargetTriple, "", "", Options, std::nullopt));
    if (!TM)
      GTEST_SKIP() << "Failed to create X86 target machine; Skipping test";

    // Set the data layout to match the target machine.
    M->setDataLayout(TM->createDataLayout());

    // Create a dummy function so a subtarget (and a MachineFunction) exists.
    FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false);
    Function *F =
        Function::Create(FT, Function::ExternalLinkage, "test", M.get());

    // MMI owns MF; MF provides the MachineRegisterInfo used by the vocab.
    MMI = std::make_unique<MachineModuleInfo>(TM.get());
    MF = &MMI->getOrCreateMachineFunction(*F);

    // Target instruction and register info come from the function's
    // subtarget; query it once.
    const TargetSubtargetInfo *STI = TM->getSubtargetImpl(*F);
    TII = STI->getInstrInfo();
    TRI = STI->getRegisterInfo();
    if (!TII || !TRI)
      GTEST_SKIP()
          << "Failed to get target instruction/register info; Skipping test";
  }

  void TearDown() override {
    // Clear the non-owning pointers; the owning members are released when
    // the fixture is destroyed.
    TII = nullptr;
    TRI = nullptr;
    MF = nullptr;
  }

  /// Returns the opcode whose target name is exactly \p Name, or -1 if no
  /// opcode matches. The scan starts at opcode 1, so opcode 0 is never
  /// returned.
  int findOpcodeByName(StringRef Name) {
    for (unsigned Opcode = 1, E = TII->getNumOpcodes(); Opcode < E; ++Opcode)
      if (TII->getName(Opcode) == Name)
        return Opcode;
    return -1; // Not found
  }

  /// Creates a vocabulary in which every listed entry maps to a constant
  /// vector of length \p Dimension filled with the given value.
  ///
  /// Any section left empty by the caller is padded with a single zero entry
  /// ("NOOP", "Immediate", "GR32", "GR32" respectively), because
  /// MIRVocabulary::create() rejects empty sections. This might cause errors
  /// in future if the validation in MIRVocabulary::generateStorage() enforces
  /// hard checks on the vocabulary entries.
  Expected<MIRVocabulary> createTestVocab(
      std::initializer_list<std::pair<const char *, float>> Opcodes,
      std::initializer_list<std::pair<const char *, float>> CommonOperands,
      std::initializer_list<std::pair<const char *, float>> PhyRegs,
      std::initializer_list<std::pair<const char *, float>> VirtRegs,
      unsigned Dimension = 2) {
    assert(TII && TRI && MF && "Target info not initialized");
    VocabMap OpcodeMap, CommonOperandMap, PhyRegMap, VirtRegMap;
    for (const auto &[Name, Value] : Opcodes)
      OpcodeMap[Name] = Embedding(Dimension, Value);
    for (const auto &[Name, Value] : CommonOperands)
      CommonOperandMap[Name] = Embedding(Dimension, Value);
    for (const auto &[Name, Value] : PhyRegs)
      PhyRegMap[Name] = Embedding(Dimension, Value);
    for (const auto &[Name, Value] : VirtRegs)
      VirtRegMap[Name] = Embedding(Dimension, Value);

    // Pad empty sections with a minimal entry so validation passes.
    if (Opcodes.size() == 0)
      OpcodeMap["NOOP"] = Embedding(Dimension, 0.0f);
    if (CommonOperands.size() == 0)
      CommonOperandMap["Immediate"] = Embedding(Dimension, 0.0f);
    if (PhyRegs.size() == 0)
      PhyRegMap["GR32"] = Embedding(Dimension, 0.0f);
    if (VirtRegs.size() == 0)
      VirtRegMap["GR32"] = Embedding(Dimension, 0.0f);

    return MIRVocabulary::create(
        std::move(OpcodeMap), std::move(CommonOperandMap), std::move(PhyRegMap),
        std::move(VirtRegMap), *TII, *TRI, MF->getRegInfo());
  }
};
// Parameterized fixture for the empty-vocabulary-section tests. The integer
// parameter selects which section the test leaves empty (0 = opcodes,
// 1 = common operands, 2 = physical registers, 3 = virtual registers).
class MIR2VecVocabEmptySectionTestFixture
    : public MIR2VecVocabTestFixture,
      public ::testing::WithParamInterface<int> {
protected:
  void SetUp() override {
    MIR2VecVocabTestFixture::SetUp();
    // TII is still null if the base SetUp() bailed out before obtaining the
    // target info; record the skip for this fixture as well.
    if (TII != nullptr)
      return;
    GTEST_SKIP() << "Failed to get target instruction info in "
                    "the base class setup; Skipping test";
  }
};
TEST_P(MIR2VecVocabEmptySectionTestFixture, EmptySectionFailsValidation) {
  const int EmptySection = GetParam();

  // One seed entry per section, in the order MIRVocabulary::create() takes
  // them: opcodes, common operands, physical regs, virtual regs. The section
  // selected by the parameter is left without any entry.
  struct SectionSeed {
    const char *Key;
    float Value;
  };
  const SectionSeed Seeds[4] = {
      {"ADD", 1.0f}, {"Immediate", 0.0f}, {"GR32", 0.0f}, {"GR32", 0.0f}};
  VocabMap Sections[4];
  for (int Section = 0; Section < 4; ++Section)
    if (Section != EmptySection)
      Sections[Section][Seeds[Section].Key] =
          Embedding(2, Seeds[Section].Value);

  ASSERT_TRUE(TII != nullptr);
  ASSERT_TRUE(TRI != nullptr);
  ASSERT_TRUE(MF != nullptr);

  auto VocabOrErr = MIRVocabulary::create(
      std::move(Sections[0]), std::move(Sections[1]), std::move(Sections[2]),
      std::move(Sections[3]), *TII, *TRI, MF->getRegInfo());
  EXPECT_FALSE(static_cast<bool>(VocabOrErr))
      << "Factory method should fail when section " << EmptySection
      << " is empty";

  // When creation fails, the error must carry a non-empty message.
  if (!VocabOrErr)
    EXPECT_FALSE(toString(VocabOrErr.takeError()).empty());
}
// Run EmptySectionFailsValidation once per vocabulary section: 0, 1, 2, 3.
INSTANTIATE_TEST_SUITE_P(EmptySection, MIR2VecVocabEmptySectionTestFixture,
                         ::testing::Range(0, 4));
TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
  // Width/operand-kind variants of one mnemonic share a base name...
  const std::string Add16Base = MIRVocabulary::extractBaseOpcodeName("ADD16ri");
  const std::string Add32Base = MIRVocabulary::extractBaseOpcodeName("ADD32rr");
  const std::string Add64Base = MIRVocabulary::extractBaseOpcodeName("ADD64rm");
  EXPECT_EQ(Add16Base, Add32Base);
  EXPECT_EQ(Add32Base, Add64Base);

  // ...and therefore one canonical index. A vocabulary that only knows "ADD"
  // is enough to trigger construction of the full canonical mapping.
  auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {}, 64);
  ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
      << "Failed to create vocabulary: "
      << toString(TestVocabOrErr.takeError());
  auto &TestVocab = *TestVocabOrErr;

  const unsigned Add16Idx = TestVocab.getCanonicalIndexForBaseName(Add16Base);
  const unsigned Add32Idx = TestVocab.getCanonicalIndexForBaseName(Add32Base);
  const unsigned Add64Idx = TestVocab.getCanonicalIndexForBaseName(Add64Base);
  EXPECT_EQ(Add16Idx, Add32Idx);
  EXPECT_EQ(Add32Idx, Add64Idx);

  // Distinct mnemonics map to distinct canonical indices.
  const std::string AddName = MIRVocabulary::extractBaseOpcodeName("ADD32rr");
  const std::string SubName = MIRVocabulary::extractBaseOpcodeName("SUB32rr");
  const std::string MovName = MIRVocabulary::extractBaseOpcodeName("MOV32rr");
  const unsigned AddIdx = TestVocab.getCanonicalIndexForBaseName(AddName);
  const unsigned SubIdx = TestVocab.getCanonicalIndexForBaseName(SubName);
  const unsigned MovIdx = TestVocab.getCanonicalIndexForBaseName(MovName);
  EXPECT_NE(AddIdx, SubIdx);
  EXPECT_NE(SubIdx, MovIdx);
  EXPECT_NE(AddIdx, MovIdx);

  // Although only "ADD" was put in the vocab, the canonical mapping covers
  // every base opcode of the target. Checking the exact X86 count would be
  // brittle, so only require that it is well above the handful of opcodes
  // added here (X86 has more than 6880 unique base opcodes).
  EXPECT_GT(TestVocab.getCanonicalSize(), 6880u);

  // ADD32rr picks up the "ADD" entry; SUB32rr and MOV32rr have no entry and
  // therefore get zero vectors.
  const Embedding AddEmbedding(64, 1.0f);
  const Embedding ZeroEmbedding(64, 0.0f);

  const int Add32rr = findOpcodeByName("ADD32rr");
  ASSERT_NE(Add32rr, -1) << "ADD32rr opcode not found";
  EXPECT_TRUE(TestVocab[Add32rr].approximatelyEquals(AddEmbedding));

  const int Sub32rr = findOpcodeByName("SUB32rr");
  ASSERT_NE(Sub32rr, -1) << "SUB32rr opcode not found";
  EXPECT_TRUE(TestVocab[Sub32rr].approximatelyEquals(ZeroEmbedding));

  const int Mov32rr = findOpcodeByName("MOV32rr");
  ASSERT_NE(Mov32rr, -1) << "MOV32rr opcode not found";
  EXPECT_TRUE(TestVocab[Mov32rr].approximatelyEquals(ZeroEmbedding));
}
// Test deterministic mapping: repeated queries for the same base name on one
// vocabulary must always return the same canonical index.
TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
  const std::string BaseName = "ADD";

  // A minimal vocabulary is enough to trigger canonical mapping construction.
  auto TestVocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {}, 64);
  ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
      << "Failed to create vocabulary: "
      << toString(TestVocabOrErr.takeError());
  auto &TestVocab = *TestVocabOrErr;

  unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
  unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
  unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName);
  // Compare every pair of consecutive lookups, including the first one.
  EXPECT_EQ(Index1, Index2);
  EXPECT_EQ(Index2, Index3);

  // Many more repeated lookups on the same vocabulary instance.
  for (int Pos = 0; Pos < 100; ++Pos)
    EXPECT_EQ(TestVocab.getCanonicalIndexForBaseName(BaseName), Index1);
}
// Test MIRVocabulary construction
TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
  auto VocabOrErr =
      createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, {}, {}, {}, 128);
  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
  auto &Vocab = *VocabOrErr;
  EXPECT_EQ(Vocab.getDimension(), 128u);

  // The iterator walks individual embeddings. ASSERT (not EXPECT) that there
  // is at least one, so the dereference below can never touch end().
  auto First = Vocab.begin();
  ASSERT_NE(First, Vocab.end());
  EXPECT_EQ((*First).size(), 128u);

  // Every embedding yielded by the iterator has the configured dimension.
  // (Distinct loop variable name: the original shadowed the outer iterator.)
  size_t Count = 0;
  for (auto It = Vocab.begin(); It != Vocab.end(); ++It) {
    EXPECT_EQ((*It).size(), 128u);
    ++Count;
  }
  EXPECT_GT(Count, 0u);
}
// Fixture for embedding related tests. Adds helpers that append machine
// instructions, selected by opcode number or by opcode name, to a block.
class MIR2VecEmbeddingTestFixture : public MIR2VecVocabTestFixture {
protected:
  void SetUp() override {
    MIR2VecVocabTestFixture::SetUp();
    // If base class setup was skipped (TII not initialized), skip derived setup
    if (!TII)
      GTEST_SKIP() << "Failed to get target instruction info in "
                      "the base class setup; Skipping test";
  }

  /// Appends an instruction with opcode \p Opcode (and no operands) to the
  /// end of \p MBB and returns it. Operands are not added; the opcode-based
  /// embedding tests do not need them.
  MachineInstr *createMachineInstr(MachineBasicBlock &MBB, unsigned Opcode) {
    const MCInstrDesc &Desc = TII->get(Opcode);
    return BuildMI(MBB, MBB.end(), DebugLoc(), Desc);
  }

  /// Same as above, but looks the opcode up by its target name. Returns
  /// nullptr if no opcode has that name.
  MachineInstr *createMachineInstr(MachineBasicBlock &MBB,
                                   const char *OpcodeName) {
    int Opcode = findOpcodeByName(OpcodeName);
    if (Opcode == -1)
      return nullptr;
    return createMachineInstr(MBB, Opcode);
  }

  /// Appends one instruction per name in \p Opcodes, in order.
  ///
  /// On an unknown name this records a fatal failure naming the opcode and
  /// returns. A fatal failure inside a helper does not abort the calling
  /// test, so callers should wrap the call in ASSERT_NO_FATAL_FAILURE.
  void createMachineInstrs(MachineBasicBlock &MBB,
                           std::initializer_list<const char *> Opcodes) {
    for (const char *OpcodeName : Opcodes)
      ASSERT_TRUE(createMachineInstr(MBB, OpcodeName) != nullptr)
          << "Opcode not found: " << OpcodeName;
  }
};
// MIREmbedder::create() with the Symbolic kind yields a non-null embedder.
TEST_F(MIR2VecEmbeddingTestFixture, CreateSymbolicEmbedder) {
  auto &MRI = MF->getRegInfo();
  auto DummyVocabOrErr =
      MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MRI, 1);
  ASSERT_TRUE(static_cast<bool>(DummyVocabOrErr))
      << "Failed to create vocabulary: "
      << toString(DummyVocabOrErr.takeError());

  auto Embedder =
      MIREmbedder::create(MIR2VecKind::Symbolic, *MF, *DummyVocabOrErr);
  EXPECT_NE(Embedder, nullptr);
}
TEST_F(MIR2VecEmbeddingTestFixture, CreateInvalidMode) {
  // An out-of-range MIR2VecKind value must make the factory fail.
  auto &MRI = MF->getRegInfo();
  auto DummyVocabOrErr =
      MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MRI, 1);
  ASSERT_TRUE(static_cast<bool>(DummyVocabOrErr))
      << "Failed to create vocabulary: "
      << toString(DummyVocabOrErr.takeError());

  const auto BogusKind = static_cast<MIR2VecKind>(-1);
  auto Result = MIREmbedder::create(BogusKind, *MF, *DummyVocabOrErr);
  EXPECT_FALSE(static_cast<bool>(Result));
}
// Test SymbolicMIREmbedder with simple target opcodes
TEST_F(MIR2VecEmbeddingTestFixture, TestSymbolicEmbedder) {
  // 4-dimensional constant embeddings: NOOP -> 1s, RET -> 2s, TRAP -> 3s.
  auto VocabOrErr = createTestVocab(
      {{"NOOP", 1.0f}, {"RET", 2.0f}, {"TRAP", 3.0f}}, {}, {}, {}, 4);
  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
  auto &Vocab = *VocabOrErr;

  // One block holding NOOP, RET64 and TRAP, in that order.
  MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
  MF->push_back(MBB);

  MachineInstr *NoopMI = createMachineInstr(*MBB, "NOOP");
  ASSERT_TRUE(NoopMI != nullptr);
  // RET64 is expected to pick up the "RET" vocabulary entry.
  MachineInstr *RetMI = createMachineInstr(*MBB, "RET64");
  ASSERT_TRUE(RetMI != nullptr);
  MachineInstr *TrapMI = createMachineInstr(*MBB, "TRAP");
  ASSERT_TRUE(TrapMI != nullptr);

  // All three must be real (non-pseudo) X86 instructions.
  ASSERT_FALSE(NoopMI->isPseudo()) << "NOOP is marked as pseudo instruction";
  ASSERT_FALSE(RetMI->isPseudo()) << "RET is marked as pseudo instruction";
  ASSERT_FALSE(TrapMI->isPseudo()) << "TRAP is marked as pseudo instruction";

  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
  ASSERT_TRUE(Embedder != nullptr);

  // Per-instruction vectors are the opcode entries scaled by the opcode
  // weight (mir2vec::OpcWeight).
  const float W = mir2vec::OpcWeight;
  auto NoopVec = Embedder->getMInstVector(*NoopMI);
  auto RetVec = Embedder->getMInstVector(*RetMI);
  auto TrapVec = Embedder->getMInstVector(*TrapMI);
  EXPECT_TRUE(NoopVec.approximatelyEquals(Embedding(4, 1.0f * W)));
  EXPECT_TRUE(RetVec.approximatelyEquals(Embedding(4, 2.0f * W)));
  EXPECT_TRUE(TrapVec.approximatelyEquals(Embedding(4, 3.0f * W)));

  // The block vector is the sum of its instructions: (1 + 2 + 3) * W = 6 * W
  // in every component.
  const Embedding BlockSum(4, 6.0f * W);
  auto BlockVec = Embedder->getMBBVector(*MBB);
  EXPECT_TRUE(BlockVec.approximatelyEquals(BlockSum));

  // With a single block, the function vector equals the block vector.
  auto FuncVec = Embedder->getMFunctionVector();
  EXPECT_TRUE(FuncVec.approximatelyEquals(BlockSum));
}
// Test embedder with multiple basic blocks
TEST_F(MIR2VecEmbeddingTestFixture, MultipleBasicBlocks) {
  // 2-dimensional constant embeddings: NOOP -> 1s, TRAP -> 2s.
  auto VocabOrErr =
      createTestVocab({{"NOOP", 1.0f}, {"TRAP", 2.0f}}, {}, {}, {});
  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
  auto &Vocab = *VocabOrErr;

  // MBB1 (entry): NOOP, NOOP. MBB2: TRAP.
  MachineBasicBlock *MBB1 = MF->CreateMachineBasicBlock();
  MachineBasicBlock *MBB2 = MF->CreateMachineBasicBlock();
  MF->push_back(MBB1);
  MF->push_back(MBB2);
  // createMachineInstrs() reports a missing opcode as a fatal failure inside
  // the helper; ASSERT_NO_FATAL_FAILURE propagates it so the test stops here.
  ASSERT_NO_FATAL_FAILURE(createMachineInstrs(*MBB1, {"NOOP", "NOOP"}));
  MachineInstr *TrapMI = createMachineInstr(*MBB2, "TRAP");
  ASSERT_TRUE(TrapMI != nullptr);

  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
  ASSERT_TRUE(Embedder != nullptr);

  // Per-block embeddings.
  auto MBB1Vector = Embedder->getMBBVector(*MBB1);
  auto MBB2Vector = Embedder->getMBBVector(*MBB2);
  float ExpectedWeight = mir2vec::OpcWeight;
  // BB1: NOOP + NOOP = 2 * ([1, 1] * weight)
  Embedding ExpectedMBB1Vector(2, 2.0f * ExpectedWeight);
  EXPECT_TRUE(MBB1Vector.approximatelyEquals(ExpectedMBB1Vector));
  // BB2: TRAP = [2, 2] * weight
  Embedding ExpectedMBB2Vector(2, 2.0f * ExpectedWeight);
  EXPECT_TRUE(MBB2Vector.approximatelyEquals(ExpectedMBB2Vector));

  // There is no CFG edge yet, so MBB2 is unreachable from the entry block and
  // the function embedding covers MBB1 only.
  auto MFuncVector = Embedder->getMFunctionVector();
  EXPECT_TRUE(MFuncVector.approximatelyEquals(ExpectedMBB1Vector));

  // Add an MBB1 -> MBB2 edge so both blocks are reachable. The recomputed
  // function embedding is then MBB1 + MBB2 = [4, 4] * weight.
  MBB1->addSuccessor(MBB2);
  auto NewMFuncVector = Embedder->getMFunctionVector(); // Recompute embeddings
  Embedding ExpectedFuncVector = MBB1Vector + MBB2Vector;
  EXPECT_TRUE(NewMFuncVector.approximatelyEquals(ExpectedFuncVector));
}
// Test embedder with empty basic block
TEST_F(MIR2VecEmbeddingTestFixture, EmptyBasicBlock) {
  // The function's only block contains no instructions.
  MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
  MF->push_back(MBB);

  // Build a 2-dimensional dummy vocabulary, then the embedder.
  auto VocabOrErr =
      MIRVocabulary::createDummyVocabForTest(*TII, *TRI, MF->getRegInfo(), 2);
  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
  auto Embedder = SymbolicMIREmbedder::create(*MF, *VocabOrErr);
  ASSERT_TRUE(Embedder != nullptr);

  // With no instructions, both the block and the function embedding are
  // zero vectors.
  const Embedding Zero(2, 0.0f);
  auto BlockVec = Embedder->getMBBVector(*MBB);
  EXPECT_TRUE(BlockVec.approximatelyEquals(Zero));
  auto FuncVec = Embedder->getMFunctionVector();
  EXPECT_TRUE(FuncVec.approximatelyEquals(Zero));
}
// Test embedder with opcodes not in vocabulary
TEST_F(MIR2VecEmbeddingTestFixture, UnknownOpcodes) {
  // Only "ADD" gets a vocabulary entry; "SUB" is deliberately left out.
  auto VocabOrErr = createTestVocab({{"ADD", 1.0f}}, {}, {}, {});
  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
  auto &Vocab = *VocabOrErr;

  MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
  MF->push_back(MBB);

  const int AddOpc = findOpcodeByName("ADD32rr");
  const int SubOpc = findOpcodeByName("SUB32rr");
  ASSERT_NE(AddOpc, -1) << "ADD32rr opcode not found";
  ASSERT_NE(SubOpc, -1) << "SUB32rr opcode not found";
  MachineInstr *AddMI = createMachineInstr(*MBB, AddOpc);
  MachineInstr *SubMI = createMachineInstr(*MBB, SubOpc);

  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
  ASSERT_TRUE(Embedder != nullptr);

  const float Scale = mir2vec::OpcWeight;
  auto AddVector = Embedder->getMInstVector(*AddMI);
  auto SubVector = Embedder->getMInstVector(*SubMI);
  // Known opcode: its vocabulary entry, scaled by the opcode weight.
  EXPECT_TRUE(AddVector.approximatelyEquals(Embedding(2, 1.0f * Scale)));
  // Unknown opcode: a zero vector.
  EXPECT_TRUE(SubVector.approximatelyEquals(Embedding(2, 0.0f)));

  // Block = ADD + SUB = [1, 1] * Scale + [0, 0].
  const auto &MBBVector = Embedder->getMBBVector(*MBB);
  EXPECT_TRUE(MBBVector.approximatelyEquals(Embedding(2, 1.0f * Scale)));
}
// Test vocabulary string key generation
TEST_F(MIR2VecEmbeddingTestFixture, VocabularyStringKeys) {
  auto VocabOrErr =
      createTestVocab({{"ADD", 1.0f}, {"SUB", 2.0f}}, {}, {}, {}, 2);
  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
  auto &Vocab = *VocabOrErr;

  // Every canonical slot must have a non-empty key.
  for (size_t Slot = 0; Slot < Vocab.getCanonicalSize(); ++Slot) {
    const std::string SlotKey = Vocab.getStringKey(Slot);
    EXPECT_FALSE(SlotKey.empty()) << "Empty key at position " << Slot;
  }

  // Key text for a given canonical index.
  auto KeyOf = [&Vocab](unsigned Index) -> std::string {
    return Vocab.getStringKey(Index);
  };

  // Opcode and common-operand keys are the bare names.
  EXPECT_EQ(KeyOf(Vocab.getCanonicalIndexForBaseName("ADD")), "ADD");
  EXPECT_EQ(KeyOf(Vocab.getCanonicalIndexForBaseName("SUB")), "SUB");
  EXPECT_EQ(KeyOf(Vocab.getCanonicalIndexForOperandName("Immediate")),
            "Immediate");
  // Register-class keys carry a PhyReg_/VirtReg_ prefix.
  EXPECT_EQ(KeyOf(Vocab.getCanonicalIndexForRegisterClass("GR32", true)),
            "PhyReg_GR32");
  EXPECT_EQ(KeyOf(Vocab.getCanonicalIndexForRegisterClass("GR32", false)),
            "VirtReg_GR32");
}
// Test vocabulary dimension consistency
TEST_F(MIR2VecEmbeddingTestFixture, DimensionConsistency) {
  auto VocabOrErr = createTestVocab({{"TEST", 1.0f}}, {}, {}, {}, 5);
  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
  auto &Vocab = *VocabOrErr;
  EXPECT_EQ(Vocab.getDimension(), 5u);

  // All embeddings should have the same dimension. Count them as well, so
  // the check cannot pass vacuously on a vocabulary that yields nothing.
  size_t Count = 0;
  for (auto IT = Vocab.begin(); IT != Vocab.end(); ++IT) {
    EXPECT_EQ((*IT).size(), 5u);
    ++Count;
  }
  EXPECT_GT(Count, 0u);
}
// Test invalid register handling through machine instruction creation
TEST_F(MIR2VecEmbeddingTestFixture, InvalidRegisterHandling) {
  float MOVValue = 1.5f;
  float ImmValue = 0.5f;
  float PhyRegValue = 0.2f;
  auto VocabOrErr = createTestVocab(
      {{"MOV", MOVValue}}, {{"Immediate", ImmValue}},
      {{"GR8_ABCD_H", PhyRegValue}, {"GR8_ABCD_L", PhyRegValue + 0.1f}}, {}, 3);
  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
  auto &Vocab = *VocabOrErr;
  MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
  MF->push_back(MBB);

  // Create a MOV instruction with actual operands including $noreg. This
  // tests the scenario where invalid registers are encountered.
  auto MovOpcode = findOpcodeByName("MOV32mr");
  ASSERT_NE(MovOpcode, -1) << "MOV32mr opcode not found";
  const MCInstrDesc &Desc = TII->get(MovOpcode);

  // Select the physical registers by name, not by their numeric position in
  // the generated register enum. The expected embedding below assumes AH is
  // keyed under GR8_ABCD_H and AL under GR8_ABCD_L.
  auto FindPhysRegByName = [this](StringRef Name) -> unsigned {
    for (unsigned Reg = 1, E = TRI->getNumRegs(); Reg < E; ++Reg)
      if (Name == TRI->getName(Reg))
        return Reg;
    return 0; // Not found ($noreg)
  };
  unsigned BaseReg = FindPhysRegByName("AH");
  unsigned ValueReg = FindPhysRegByName("AL");
  ASSERT_NE(BaseReg, 0u) << "AH register not found";
  ASSERT_NE(ValueReg, 0u) << "AL register not found";

  // MOV32mr operands: base, scale, index, displacement, segment, value.
  auto MovInst = BuildMI(*MBB, MBB->end(), DebugLoc(), Desc)
                     .addReg(BaseReg)   // base
                     .addImm(1)         // scale
                     .addReg(0)         // index ($noreg)
                     .addImm(-4)        // displacement
                     .addReg(0)         // segment ($noreg)
                     .addReg(ValueReg); // value

  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
  ASSERT_TRUE(Embedder != nullptr);

  // This must not crash even though the instruction has $noreg operands.
  auto InstEmb = Embedder->getMInstVector(*MovInst);
  EXPECT_EQ(InstEmb.size(), 3u);

  // Opcode contribution plus per-operand contributions; $noreg adds nothing.
  Embedding ExpectedOpcodeContribution(3, MOVValue * mir2vec::OpcWeight);
  auto ExpectedOperandContribution =
      Embedding(3, PhyRegValue * mir2vec::RegOperandWeight)      // Base
      + Embedding(3, ImmValue * mir2vec::CommonOperandWeight)    // Scale
      + Embedding(3, 0.0f)                                       // noreg
      + Embedding(3, ImmValue * mir2vec::CommonOperandWeight)    // displacement
      + Embedding(3, 0.0f)                                       // noreg
      + Embedding(3, (PhyRegValue + 0.1f) * mir2vec::RegOperandWeight); // Value
  auto ExpectedEmb = ExpectedOpcodeContribution + ExpectedOperandContribution;
  EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb))
      << "MOV instruction embedding should match expected embedding";
}
// Test handling of both physical and virtual registers in an instruction
TEST_F(MIR2VecEmbeddingTestFixture, PhysicalAndVirtualRegisterHandling) {
  const float MOVValue = 2.0f;
  const float ImmValue = 0.7f;
  const float PhyRegValue = 0.3f;
  const float VirtRegValue = 0.9f;

  // Locate the GR32 register class by name, stopping at the first match.
  const TargetRegisterClass *GR32RC = nullptr;
  for (unsigned RCIdx = 0, NumRCs = TRI->getNumRegClasses();
       GR32RC == nullptr && RCIdx < NumRCs; ++RCIdx) {
    const TargetRegisterClass *Candidate = TRI->getRegClass(RCIdx);
    if (StringRef(TRI->getRegClassName(Candidate)) == "GR32")
      GR32RC = Candidate;
  }
  ASSERT_TRUE(GR32RC != nullptr && GR32RC->isAllocatable())
      << "No allocatable GR32 register class found";

  // The first physical register of GR32, plus a fresh GR32 virtual register.
  unsigned PhyReg = *GR32RC->begin();
  unsigned VirtReg = MF->getRegInfo().createVirtualRegister(GR32RC);

  // Register-class keyed vocabulary. The physical entry uses GR32_AD (the
  // minimal key for the physical register); the virtual entry uses GR32.
  auto VocabOrErr =
      createTestVocab({{"MOV", MOVValue}}, {{"Immediate", ImmValue}},
                      {{"GR32_AD", PhyRegValue}}, {{"GR32", VirtRegValue}}, 4);
  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
  auto &Vocab = *VocabOrErr;

  MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
  MF->push_back(MBB);

  // MOV32rr dst, src with a physical dst and a virtual src.
  auto MovOpcode = findOpcodeByName("MOV32rr");
  ASSERT_NE(MovOpcode, -1) << "MOV32rr opcode not found";
  auto MovInst = BuildMI(*MBB, MBB->end(), DebugLoc(), TII->get(MovOpcode))
                     .addReg(PhyReg)   // physical register destination
                     .addReg(VirtReg); // virtual register source

  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
  ASSERT_TRUE(Embedder != nullptr);

  // Must not crash, and must produce a 4-dimensional embedding.
  auto InstEmb = Embedder->getMInstVector(*MovInst);
  EXPECT_EQ(InstEmb.size(), 4u);

  // Expected = opcode part + physical dst part + virtual src part.
  const Embedding OpcodePart(4, MOVValue * mir2vec::OpcWeight);
  const auto RegisterPart =
      Embedding(4, PhyRegValue * mir2vec::RegOperandWeight) +
      Embedding(4, VirtRegValue * mir2vec::RegOperandWeight);
  const auto ExpectedEmb = OpcodePart + RegisterPart;
  EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb))
      << "MOV32rr instruction embedding should match expected embedding";
}
// Test precise embedding calculation with known operands
TEST_F(MIR2VecEmbeddingTestFixture, EmbeddingCalculation) {
  auto VocabOrErr = createTestVocab({{"NOOP", 2.0f}}, {}, {}, {}, 2);
  ASSERT_TRUE(static_cast<bool>(VocabOrErr))
      << "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
  auto &Vocab = *VocabOrErr;
  MachineBasicBlock *MBB = MF->CreateMachineBasicBlock();
  MF->push_back(MBB);

  // Create a simple NOOP instruction (no operands).
  auto NoopInst = createMachineInstr(*MBB, "NOOP");
  ASSERT_TRUE(NoopInst != nullptr);
  auto Embedder = SymbolicMIREmbedder::create(*MF, Vocab);
  ASSERT_TRUE(Embedder != nullptr);

  auto InstEmb = Embedder->getMInstVector(*NoopInst);
  // ASSERT rather than EXPECT: the element checks below read InstEmb[0] and
  // InstEmb[1], which would be out of bounds on a shorter vector.
  ASSERT_EQ(InstEmb.size(), 2u);

  // For NOOP with no operands, the embedding is exactly the (weighted)
  // opcode embedding.
  float ExpectedWeight = mir2vec::OpcWeight;
  Embedding ExpectedEmb(2, 2.0f * ExpectedWeight);
  EXPECT_TRUE(InstEmb.approximatelyEquals(ExpectedEmb))
      << "NOOP instruction embedding should match opcode embedding";

  // Verify individual components.
  EXPECT_FLOAT_EQ(InstEmb[0], 2.0f * ExpectedWeight);
  EXPECT_FLOAT_EQ(InstEmb[1], 2.0f * ExpectedWeight);
}
} // namespace