blob: 9e47b2cd8bedd5ba5e688b653d4766daf28a250a [file] [log] [blame]
//===- IR2VecTest.cpp - Unit tests for IR2Vec -----------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Error.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <map>
#include <vector>
using namespace llvm;
using namespace ir2vec;
using namespace ::testing;
namespace {
class TestableEmbedder : public Embedder {
public:
TestableEmbedder(const Function &F, const Vocab &V) : Embedder(F, V) {}
void computeEmbeddings() const override {}
void computeEmbeddings(const BasicBlock &BB) const override {}
using Embedder::lookupVocab;
static void addVectors(Embedding &Dst, const Embedding &Src) {
Embedder::addVectors(Dst, Src);
}
static void addScaledVector(Embedding &Dst, const Embedding &Src,
float Factor) {
Embedder::addScaledVector(Dst, Src, Factor);
}
};
TEST(IR2VecTest, CreateSymbolicEmbedder) {
Vocab V = {{"foo", {1.0, 2.0}}};
LLVMContext Ctx;
Module M("M", Ctx);
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_TRUE(static_cast<bool>(Result));
auto *Emb = Result->get();
EXPECT_NE(Emb, nullptr);
}
TEST(IR2VecTest, CreateInvalidMode) {
Vocab V = {{"foo", {1.0, 2.0}}};
LLVMContext Ctx;
Module M("M", Ctx);
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
// static_cast an invalid int to IR2VecKind
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
EXPECT_FALSE(static_cast<bool>(Result));
std::string ErrMsg;
llvm::handleAllErrors(
Result.takeError(),
[&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
}
TEST(IR2VecTest, AddVectors) {
Embedding E1 = {1.0, 2.0, 3.0};
Embedding E2 = {0.5, 1.5, -1.0};
TestableEmbedder::addVectors(E1, E2);
EXPECT_THAT(E1, ElementsAre(1.5, 3.5, 2.0));
// Check that E2 is unchanged
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
}
TEST(IR2VecTest, AddScaledVector) {
Embedding E1 = {1.0, 2.0, 3.0};
Embedding E2 = {2.0, 0.5, -1.0};
TestableEmbedder::addScaledVector(E1, E2, 0.5f);
EXPECT_THAT(E1, ElementsAre(2.0, 2.25, 2.5));
// Check that E2 is unchanged
EXPECT_THAT(E2, ElementsAre(2.0, 0.5, -1.0));
}
#if GTEST_HAS_DEATH_TEST
#ifndef NDEBUG
TEST(IR2VecTest, MismatchedDimensionsAddVectors) {
Embedding E1 = {1.0, 2.0};
Embedding E2 = {1.0};
EXPECT_DEATH(TestableEmbedder::addVectors(E1, E2),
"Vectors must have the same dimension");
}
TEST(IR2VecTest, MismatchedDimensionsAddScaledVector) {
Embedding E1 = {1.0, 2.0};
Embedding E2 = {1.0};
EXPECT_DEATH(TestableEmbedder::addScaledVector(E1, E2, 1.0f),
"Vectors must have the same dimension");
}
#endif // NDEBUG
#endif // GTEST_HAS_DEATH_TEST
TEST(IR2VecTest, LookupVocab) {
Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
LLVMContext Ctx;
Module M("M", Ctx);
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
TestableEmbedder E(*F, V);
auto V_foo = E.lookupVocab("foo");
EXPECT_EQ(V_foo.size(), 2u);
EXPECT_THAT(V_foo, ElementsAre(1.0, 2.0));
auto V_missing = E.lookupVocab("missing");
EXPECT_EQ(V_missing.size(), 2u);
EXPECT_THAT(V_missing, ElementsAre(0.0, 0.0));
}
TEST(IR2VecTest, ZeroDimensionEmbedding) {
Embedding E1;
Embedding E2;
// Should be no-op, but not crash
TestableEmbedder::addVectors(E1, E2);
TestableEmbedder::addScaledVector(E1, E2, 1.0f);
EXPECT_TRUE(E1.empty());
}
TEST(IR2VecTest, IR2VecVocabResultValidity) {
// Default constructed is invalid
IR2VecVocabResult invalidResult;
EXPECT_FALSE(invalidResult.isValid());
#if GTEST_HAS_DEATH_TEST
#ifndef NDEBUG
EXPECT_DEATH(invalidResult.getVocabulary(), "IR2Vec Vocabulary is invalid");
EXPECT_DEATH(invalidResult.getDimension(), "IR2Vec Vocabulary is invalid");
#endif // NDEBUG
#endif // GTEST_HAS_DEATH_TEST
// Valid vocab
Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
IR2VecVocabResult validResult(std::move(V));
EXPECT_TRUE(validResult.isValid());
EXPECT_EQ(validResult.getDimension(), 2u);
}
// Helper to create a minimal function and embedder for getter tests
struct GetterTestEnv {
Vocab V = {};
LLVMContext Ctx;
std::unique_ptr<Module> M = nullptr;
Function *F = nullptr;
BasicBlock *BB = nullptr;
Instruction *Add = nullptr;
Instruction *Ret = nullptr;
std::unique_ptr<Embedder> Emb = nullptr;
GetterTestEnv() {
V = {{"add", {1.0, 2.0}},
{"integerTy", {0.5, 0.5}},
{"constant", {0.2, 0.3}},
{"variable", {0.0, 0.0}},
{"unknownTy", {0.0, 0.0}}};
M = std::make_unique<Module>("M", Ctx);
FunctionType *FTy = FunctionType::get(
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
false);
F = Function::Create(FTy, Function::ExternalLinkage, "f", M.get());
BB = BasicBlock::Create(Ctx, "entry", F);
Argument *Arg = F->getArg(0);
llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
Ret = ReturnInst::Create(Ctx, Add, BB);
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_TRUE(static_cast<bool>(Result));
Emb = std::move(*Result);
}
};
TEST(IR2VecTest, GetInstVecMap) {
GetterTestEnv Env;
const auto &InstMap = Env.Emb->getInstVecMap();
EXPECT_EQ(InstMap.size(), 2u);
EXPECT_TRUE(InstMap.count(Env.Add));
EXPECT_TRUE(InstMap.count(Env.Ret));
EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
// Check values for add: {1.29, 2.31}
EXPECT_THAT(InstMap.at(Env.Add),
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
// vocab
EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0));
}
TEST(IR2VecTest, GetBBVecMap) {
GetterTestEnv Env;
const auto &BBMap = Env.Emb->getBBVecMap();
EXPECT_EQ(BBMap.size(), 1u);
EXPECT_TRUE(BBMap.count(Env.BB));
EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
// {1.29, 2.31}
EXPECT_THAT(BBMap.at(Env.BB),
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
TEST(IR2VecTest, GetBBVector) {
GetterTestEnv Env;
const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
EXPECT_EQ(BBVec.size(), 2u);
EXPECT_THAT(BBVec,
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
TEST(IR2VecTest, GetFunctionVector) {
GetterTestEnv Env;
const auto &FuncVec = Env.Emb->getFunctionVector();
EXPECT_EQ(FuncVec.size(), 2u);
// Function vector should match BB vector (only one BB): {1.29, 2.31}
EXPECT_THAT(FuncVec,
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
} // end anonymous namespace