blob: 11222b4d02fa366633acfa7b45eb1e7029bb4882 [file] [log] [blame]
//===- MIR2VecTest.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/MIR2Vec.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Triple.h"
#include "gtest/gtest.h"
using namespace llvm;
using namespace mir2vec;
using VocabMap = std::map<std::string, ir2vec::Embedding>;
namespace {
TEST(MIR2VecTest, RegexExtraction) {
// Test simple instruction names
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("NOP"), "NOP");
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("RET"), "RET");
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("ADD16ri"), "ADD");
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("ADD32rr"), "ADD");
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("ADD64rm"), "ADD");
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("MOV8ri"), "MOV");
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("MOV32mr"), "MOV");
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("PUSH64r"), "PUSH");
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("POP64r"), "POP");
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("JMP_4"), "JMP");
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("CALL64pcrel32"), "CALL");
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("SOME_INSTR_123"),
"SOME_INSTR");
EXPECT_EQ(MIRVocabulary::extractBaseOpcodeName("123ADD"), "ADD");
EXPECT_FALSE(MIRVocabulary::extractBaseOpcodeName("123").empty());
}
class MIR2VecVocabTestFixture : public ::testing::Test {
protected:
std::unique_ptr<LLVMContext> Ctx;
std::unique_ptr<Module> M;
std::unique_ptr<TargetMachine> TM;
const TargetInstrInfo *TII = nullptr;
static void SetUpTestCase() {
InitializeAllTargets();
InitializeAllTargetMCs();
}
void SetUp() override {
Triple TargetTriple("x86_64-unknown-linux-gnu");
std::string Error;
const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
if (!T) {
GTEST_SKIP() << "x86_64-unknown-linux-gnu target triple not available; "
"Skipping test";
return;
}
Ctx = std::make_unique<LLVMContext>();
M = std::make_unique<Module>("test", *Ctx);
M->setTargetTriple(TargetTriple);
TargetOptions Options;
TM = std::unique_ptr<TargetMachine>(
T->createTargetMachine(TargetTriple, "", "", Options, std::nullopt));
if (!TM) {
GTEST_SKIP() << "Failed to create X86 target machine; Skipping test";
return;
}
// Create a dummy function to get subtarget info
FunctionType *FT = FunctionType::get(Type::getVoidTy(*Ctx), false);
Function *F =
Function::Create(FT, Function::ExternalLinkage, "test", M.get());
// Get the target instruction info
TII = TM->getSubtargetImpl(*F)->getInstrInfo();
if (!TII) {
GTEST_SKIP() << "Failed to get target instruction info; Skipping test";
return;
}
}
void TearDown() override { TII = nullptr; }
};
// Function to find an opcode by name
static int findOpcodeByName(const TargetInstrInfo *TII, StringRef Name) {
for (unsigned Opcode = 1; Opcode < TII->getNumOpcodes(); ++Opcode) {
if (TII->getName(Opcode) == Name)
return Opcode;
}
return -1; // Not found
}
TEST_F(MIR2VecVocabTestFixture, CanonicalOpcodeMappingTest) {
// Test that same base opcodes get same canonical indices
std::string BaseName1 = MIRVocabulary::extractBaseOpcodeName("ADD16ri");
std::string BaseName2 = MIRVocabulary::extractBaseOpcodeName("ADD32rr");
std::string BaseName3 = MIRVocabulary::extractBaseOpcodeName("ADD64rm");
EXPECT_EQ(BaseName1, BaseName2);
EXPECT_EQ(BaseName2, BaseName3);
// Create a MIRVocabulary instance to test the mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
VocabMap VMap;
Embedding Val = Embedding(64, 1.0f);
VMap["ADD"] = Val;
auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
<< "Failed to create vocabulary: "
<< toString(TestVocabOrErr.takeError());
auto &TestVocab = *TestVocabOrErr;
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName1);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName2);
unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName3);
EXPECT_EQ(Index1, Index2);
EXPECT_EQ(Index2, Index3);
// Test that different base opcodes get different canonical indices
std::string AddBase = MIRVocabulary::extractBaseOpcodeName("ADD32rr");
std::string SubBase = MIRVocabulary::extractBaseOpcodeName("SUB32rr");
std::string MovBase = MIRVocabulary::extractBaseOpcodeName("MOV32rr");
unsigned AddIndex = TestVocab.getCanonicalIndexForBaseName(AddBase);
unsigned SubIndex = TestVocab.getCanonicalIndexForBaseName(SubBase);
unsigned MovIndex = TestVocab.getCanonicalIndexForBaseName(MovBase);
EXPECT_NE(AddIndex, SubIndex);
EXPECT_NE(SubIndex, MovIndex);
EXPECT_NE(AddIndex, MovIndex);
// Even though we only added "ADD" to the vocab, the canonical mapping
// should assign unique indices to all the base opcodes of the target
// Ideally, we would check against the exact number of unique base opcodes
// for X86, but that would make the test brittle. So we just check that
// the number is reasonably closer to the expected number (>6880) and not just
// opcodes that we added.
EXPECT_GT(TestVocab.getCanonicalSize(),
6880u); // X86 has >6880 unique base opcodes
// Check that the embeddings for opcodes not in the vocab are zero vectors
int Add32rrOpcode = findOpcodeByName(TII, "ADD32rr");
ASSERT_NE(Add32rrOpcode, -1) << "ADD32rr opcode not found";
EXPECT_TRUE(TestVocab[Add32rrOpcode].approximatelyEquals(Val));
int Sub32rrOpcode = findOpcodeByName(TII, "SUB32rr");
ASSERT_NE(Sub32rrOpcode, -1) << "SUB32rr opcode not found";
EXPECT_TRUE(
TestVocab[Sub32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
int Mov32rrOpcode = findOpcodeByName(TII, "MOV32rr");
ASSERT_NE(Mov32rrOpcode, -1) << "MOV32rr opcode not found";
EXPECT_TRUE(
TestVocab[Mov32rrOpcode].approximatelyEquals(Embedding(64, 0.0f)));
}
// Test deterministic mapping
TEST_F(MIR2VecVocabTestFixture, DeterministicMapping) {
// Test that the same base name always maps to the same canonical index
std::string BaseName = "ADD";
// Create a MIRVocabulary instance to test deterministic mapping
// Use a minimal MIRVocabulary to trigger canonical mapping construction
VocabMap VMap;
VMap["ADD"] = Embedding(64, 1.0f);
auto TestVocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
ASSERT_TRUE(static_cast<bool>(TestVocabOrErr))
<< "Failed to create vocabulary: "
<< toString(TestVocabOrErr.takeError());
auto &TestVocab = *TestVocabOrErr;
unsigned Index1 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index2 = TestVocab.getCanonicalIndexForBaseName(BaseName);
unsigned Index3 = TestVocab.getCanonicalIndexForBaseName(BaseName);
EXPECT_EQ(Index1, Index2);
EXPECT_EQ(Index2, Index3);
// Test across multiple runs
for (int Pos = 0; Pos < 100; ++Pos) {
unsigned Index = TestVocab.getCanonicalIndexForBaseName(BaseName);
EXPECT_EQ(Index, Index1);
}
}
// Test MIRVocabulary construction
TEST_F(MIR2VecVocabTestFixture, VocabularyConstruction) {
VocabMap VMap;
VMap["ADD"] = Embedding(128, 1.0f); // Dimension 128, all values 1.0
VMap["SUB"] = Embedding(128, 2.0f); // Dimension 128, all values 2.0
auto VocabOrErr = MIRVocabulary::create(std::move(VMap), *TII);
ASSERT_TRUE(static_cast<bool>(VocabOrErr))
<< "Failed to create vocabulary: " << toString(VocabOrErr.takeError());
auto &Vocab = *VocabOrErr;
EXPECT_EQ(Vocab.getDimension(), 128u);
// Test iterator - iterates over individual embeddings
auto IT = Vocab.begin();
EXPECT_NE(IT, Vocab.end());
// Check first embedding exists and has correct dimension
EXPECT_EQ((*IT).size(), 128u);
size_t Count = 0;
for (auto IT = Vocab.begin(); IT != Vocab.end(); ++IT) {
EXPECT_EQ((*IT).size(), 128u);
++Count;
}
EXPECT_GT(Count, 0u);
}
// Test factory method with empty vocabulary
TEST_F(MIR2VecVocabTestFixture, EmptyVocabularyCreation) {
VocabMap EmptyVMap;
auto VocabOrErr = MIRVocabulary::create(std::move(EmptyVMap), *TII);
EXPECT_FALSE(static_cast<bool>(VocabOrErr))
<< "Factory method should fail with empty vocabulary";
// Consume the error
if (!VocabOrErr) {
auto Err = VocabOrErr.takeError();
std::string ErrorMsg = toString(std::move(Err));
EXPECT_FALSE(ErrorMsg.empty());
}
}
} // namespace