blob: fa38c35796a0e2a56d7322a2e36738852072aa3d [file] [log] [blame]
//===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions. See the LICENSE file for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the IR2Vec algorithm.
///
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/MemoryBuffer.h"
using namespace llvm;
using namespace ir2vec;
#define DEBUG_TYPE "ir2vec"
// Counts vocabulary lookups that missed and fell back to an all-zero vector
// (incremented in Embedder::lookupVocab).
STATISTIC(VocabMissCounter,
          "Number of lookups to entities not present in the vocabulary");
namespace llvm::ir2vec {
/// Groups the IR2Vec command-line options together in -help output.
static cl::OptionCategory IR2VecCategory("IR2Vec Options");

/// Path of the JSON vocabulary that IR2VecVocabAnalysis reads when no
/// vocabulary was handed to its constructor. Empty by default.
// FIXME: Use a default vocab when not specified
static cl::opt<std::string>
    VocabFile("ir2vec-vocab-path", cl::Optional,
              cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
              cl::cat(IR2VecCategory));

/// Scaling factors applied to the opcode, result-type and operand embeddings
/// when an instruction vector is composed. Unlike VocabFile these have
/// external linkage; Embedder copies their values when it is constructed.
cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
                         cl::desc("Weight for opcode embeddings"),
                         cl::cat(IR2VecCategory));
cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
                          cl::desc("Weight for type embeddings"),
                          cl::cat(IR2VecCategory));
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
                         cl::desc("Weight for argument embeddings"),
                         cl::cat(IR2VecCategory));
} // namespace llvm::ir2vec
AnalysisKey IR2VecVocabAnalysis::Key;
namespace llvm::json {
/// JSON deserialisation hook for Embedding: parses E as an array of numbers
/// and stores it in Out. Presumably reached through the std::map overload
/// when readVocabulary() parses the vocabulary file.
/// Returns false, leaving Out unchanged, if E is not a numeric array.
inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
                     llvm::json::Path P) {
  std::vector<double> Values;
  const bool Parsed = llvm::json::fromJSON(E, Values, P);
  if (Parsed)
    Out = Embedding(std::move(Values));
  return Parsed;
}
} // namespace llvm::json
//===----------------------------------------------------------------------===//
// Embedding
//===----------------------------------------------------------------------===//
/// Element-wise in-place addition. Both embeddings must have the same
/// dimension; this is only checked in assert-enabled builds.
Embedding &Embedding::operator+=(const Embedding &RHS) {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  for (size_t Idx = 0, N = this->size(); Idx != N; ++Idx)
    (*this)[Idx] += RHS[Idx];
  return *this;
}
/// Element-wise in-place subtraction. Both embeddings must have the same
/// dimension; this is only checked in assert-enabled builds.
Embedding &Embedding::operator-=(const Embedding &RHS) {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  for (size_t Idx = 0, N = this->size(); Idx != N; ++Idx)
    (*this)[Idx] -= RHS[Idx];
  return *this;
}
/// In-place "axpy": this[i] += Src[i] * Factor for every component.
/// Both embeddings must have the same dimension (asserted).
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
  assert(this->size() == Src.size() && "Vectors must have the same dimension");
  std::transform(this->begin(), this->end(), Src.begin(), this->begin(),
                 [Factor](double Acc, double Val) {
                   return Acc + Val * Factor;
                 });
  return *this;
}
bool Embedding::approximatelyEquals(const Embedding &RHS,
double Tolerance) const {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
for (size_t Itr = 0; Itr < this->size(); ++Itr)
if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance)
return false;
return true;
}
//===----------------------------------------------------------------------===//
// Embedder and its subclasses
//===----------------------------------------------------------------------===//
// Binds the function and vocabulary by reference and snapshots the embedding
// dimension plus the current -ir2vec-*-weight option values.
// NOTE(review): Dimension is taken from Vocabulary.begin(), so an empty
// vocabulary would be dereferenced here; callers must pass a non-empty one.
Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
    : F(F),
      Vocabulary(Vocabulary),
      Dimension(Vocabulary.begin()->second.size()),
      OpcWeight(::OpcWeight),
      TypeWeight(::TypeWeight),
      ArgWeight(::ArgWeight) {}
/// Factory for the concrete embedder selected by Mode.
/// Returns a StringError (errc::invalid_argument) if Mode is not one of the
/// IR2VecKind values handled below.
Expected<std::unique_ptr<Embedder>>
Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
  // No `default:` label on purpose, so -Wswitch flags this switch when a new
  // IR2VecKind is introduced.
  switch (Mode) {
  case IR2VecKind::Symbolic: {
    return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
  }
  }
  // Only reached for a Mode value outside the handled enumerators.
  return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
}
// FIXME: Currently lookups are string based. Use numeric Keys
// for efficiency
/// Return a copy of the vocabulary embedding for Key.
/// If Key is not in the vocabulary, an all-zero vector of Dimension entries
/// is returned instead. The miss is counted in VocabMissCounter and logged
/// under -debug-only=ir2vec.
Embedding Embedder::lookupVocab(const std::string &Key) const {
  auto It = Vocabulary.find(Key);
  if (It != Vocabulary.end())
    return It->second;

  // FIXME: Use zero vectors in vocab and assert failure for
  // unknown entities rather than silently returning zeroes here.
  LLVM_DEBUG(dbgs() << "cannot find key in map : " << Key << "\n");
  ++VocabMissCounter;
  // Build the zero vector only on the (rare) miss path; previously one was
  // allocated on every lookup and thrown away on hits.
  return Embedding(Dimension, 0);
}
/// Return the per-instruction embeddings, computing embeddings for the whole
/// function the first time this is called on an empty map.
// NOTE(review): only emptiness is checked. If getBBVector() already filled
// the map for a single block, this returns that partial map rather than the
// embeddings of the whole function.
const InstEmbeddingsMap &Embedder::getInstVecMap() const {
  if (!InstVecMap.empty())
    return InstVecMap;
  computeEmbeddings();
  return InstVecMap;
}
/// Return the per-basic-block embeddings, computing embeddings for the whole
/// function the first time this is called on an empty map.
// NOTE(review): only emptiness is checked. If getBBVector() already filled
// the map for a single block, this returns that partial map.
const BBEmbeddingsMap &Embedder::getBBVecMap() const {
  if (!BBVecMap.empty())
    return BBVecMap;
  computeEmbeddings();
  return BBVecMap;
}
/// Return the embedding of BB. On a cache miss, only BB (and its
/// instructions) is embedded, not the rest of the function.
const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
  auto Cached = BBVecMap.find(&BB);
  if (Cached == BBVecMap.end()) {
    computeEmbeddings(BB);
    return BBVecMap[&BB];
  }
  return Cached->second;
}
/// Return the function-level embedding. The embeddings are recomputed on
/// every call instead of being cached.
const Embedding &Embedder::getFunctionVector() const {
  // Recomputing is considered cheaper than caching the vector.
  computeEmbeddings();
  return FuncVector;
}
/// Map an IR type to a coarse vocabulary key and return that key's
/// embedding. The predicates are tested in order; the first one that holds
/// picks the key, and a type matching none of them uses "unknownTy".
// NOTE(review): Type::isEmptyTy() appears to hold only for struct/array types,
// which are matched earlier, so "emptyTy" looks unreachable. Confirm before
// reordering: moving it would change the embeddings of empty aggregates.
Embedding SymbolicEmbedder::getTypeEmbedding(const Type *Ty) const {
  const char *Key;
  if (Ty->isVoidTy())
    Key = "voidTy";
  else if (Ty->isFloatingPointTy())
    Key = "floatTy";
  else if (Ty->isIntegerTy())
    Key = "integerTy";
  else if (Ty->isFunctionTy())
    Key = "functionTy";
  else if (Ty->isStructTy())
    Key = "structTy";
  else if (Ty->isArrayTy())
    Key = "arrayTy";
  else if (Ty->isPointerTy())
    Key = "pointerTy";
  else if (Ty->isVectorTy())
    Key = "vectorTy";
  else if (Ty->isEmptyTy())
    Key = "emptyTy";
  else if (Ty->isLabelTy())
    Key = "labelTy";
  else if (Ty->isTokenTy())
    Key = "tokenTy";
  else if (Ty->isMetadataTy())
    Key = "metadataTy";
  else
    Key = "unknownTy";
  return lookupVocab(Key);
}

/// Map an operand to one of four coarse vocabulary keys and return that key's
/// embedding. Functions are tested first, then any other pointer-typed value,
/// then constants; everything else is a "variable".
Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
  const char *Key = isa<Function>(Op)                 ? "function"
                    : isa<PointerType>(Op->getType()) ? "pointer"
                    : isa<Constant>(Op)               ? "constant"
                                                      : "variable";
  return lookupVocab(Key);
}
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);
// We consider only the non-debug and non-pseudo instructions
for (const auto &I : BB.instructionsWithoutDebug()) {
Embedding InstVector(Dimension, 0);
const auto OpcVec = lookupVocab(I.getOpcodeName());
InstVector.scaleAndAdd(OpcVec, OpcWeight);
// FIXME: Currently lookups are string based. Use numeric Keys
// for efficiency.
const auto Type = I.getType();
const auto TypeVec = getTypeEmbedding(Type);
InstVector.scaleAndAdd(TypeVec, TypeWeight);
for (const auto &Op : I.operands()) {
const auto OperandVec = getOperandEmbedding(Op.get());
InstVector.scaleAndAdd(OperandVec, ArgWeight);
}
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
BBVecMap[&BB] = BBVector;
}
void SymbolicEmbedder::computeEmbeddings() const {
if (F.isDeclaration())
return;
// Consider only the basic blocks that are reachable from entry
for (const BasicBlock *BB : depth_first(&F)) {
computeEmbeddings(*BB);
FuncVector += BBVecMap[BB];
}
}
//===----------------------------------------------------------------------===//
// IR2VecVocabResult and IR2VecVocabAnalysis
//===----------------------------------------------------------------------===//
/// Take ownership of a loaded vocabulary; a result built this way is always
/// marked valid.
IR2VecVocabResult::IR2VecVocabResult(ir2vec::Vocab &&Vocabulary)
    : Vocabulary(std::move(Vocabulary)),
      Valid(true) {}
/// Access the owned vocabulary. Validity is only checked in assert-enabled
/// builds, so callers should test isValid() first.
const ir2vec::Vocab &IR2VecVocabResult::getVocabulary() const {
  assert(Valid && "IR2Vec Vocabulary is invalid");
  const ir2vec::Vocab &Owned = Vocabulary;
  return Owned;
}
/// Length of the embedding vectors, taken from the first vocabulary entry.
/// readVocabulary() checks that all entries share this length for
/// file-loaded vocabularies.
unsigned IR2VecVocabResult::getDimension() const {
  assert(Valid && "IR2Vec Vocabulary is invalid");
  const auto &FirstEntry = *Vocabulary.begin();
  return FirstEntry.second.size();
}
// For now, assume vocabulary is stable unless explicitly invalidated.
/// The vocabulary is read from a file or handed to the analysis, not derived
/// from M. So the result is only dropped when the PreservedAnalyses checker
/// for IR2VecVocabAnalysis reports it as not preserved-when-stateless
/// (presumably only when it was explicitly abandoned). M and Inv are unused.
bool IR2VecVocabResult::invalidate(
    Module &M, const PreservedAnalyses &PA,
    ModuleAnalysisManager::Invalidator &Inv) const {
  const bool Preserved =
      PA.getChecker<IR2VecVocabAnalysis>().preservedWhenStateless();
  return !Preserved;
}
// FIXME: Make this optional. We can avoid file reads
// by auto-generating a default vocabulary during the build time.
/// Read the JSON vocabulary named by -ir2vec-vocab-path ("-" reads stdin)
/// into Vocabulary, then validate it. An error is returned if:
///  - the file cannot be opened or is not valid JSON;
///  - it does not deserialise into a vocabulary;
///  - the vocabulary is empty;
///  - the vectors have dimension zero or differing dimensions.
Error IR2VecVocabAnalysis::readVocabulary() {
  auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
  if (!BufOrError)
    return createFileError(VocabFile, BufOrError.getError());

  auto Content = BufOrError.get()->getBuffer();
  json::Path::Root Path("");
  Expected<json::Value> ParsedVocabValue = json::parse(Content);
  if (!ParsedVocabValue)
    return ParsedVocabValue.takeError();

  bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
  if (!Res)
    return createStringError(errc::illegal_byte_sequence,
                             "Unable to parse the vocabulary");

  if (Vocabulary.empty())
    return createStringError(errc::illegal_byte_sequence,
                             "Vocabulary is empty");

  size_t Dim = Vocabulary.begin()->second.size();
  if (Dim == 0)
    return createStringError(errc::illegal_byte_sequence,
                             "Dimension of vocabulary is zero");

  // Take each entry by `const auto &`. Spelling the parameter as
  // `const std::pair<StringRef, Embedding> &` binds to a converted temporary
  // whenever the map's value_type is not exactly that pair. That means a
  // full copy of every embedding just to read its size.
  if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
                   [Dim](const auto &Entry) {
                     return Entry.second.size() == Dim;
                   }))
    return createStringError(
        errc::illegal_byte_sequence,
        "All vectors in the vocabulary are not of the same dimension");

  return Error::success();
}
/// Construct the analysis around a caller-supplied vocabulary, either copied
/// or moved in. A non-empty vocabulary here makes run() skip the
/// vocabulary file.
IR2VecVocabAnalysis::IR2VecVocabAnalysis(const Vocab &Vocabulary)
    : Vocabulary(Vocabulary) {}
IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
    : Vocabulary(std::move(Vocabulary)) {}
/// Report every error contained in Err through Ctx, one diagnostic per error,
/// each prefixed with "Error reading vocabulary: ".
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
  handleAllErrors(std::move(Err), [&Ctx](const ErrorInfoBase &Info) {
    Ctx.emitError("Error reading vocabulary: " + Info.message());
  });
}
/// Produce the vocabulary result for M.
/// A vocabulary supplied to the constructor wins; otherwise the file named by
/// -ir2vec-vocab-path is read. On failure the problem is reported through M's
/// LLVMContext and an invalid (default-constructed) result is returned.
// NOTE(review): the stored vocabulary is moved into the result, so a later
// run() on this object (e.g. after the result was abandoned) sees a moved-from
// map and will fall back to the vocabulary file.
IR2VecVocabAnalysis::Result
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
  LLVMContext &Ctx = M.getContext();
  // FIXME: Scale the vocabulary once. This would avoid scaling per use later.
  if (Vocabulary.empty()) {
    // Nothing came from the constructor; fall back to the vocabulary file.
    if (VocabFile.empty()) {
      // FIXME: Use default vocabulary
      Ctx.emitError("IR2Vec vocabulary file path not specified");
      return IR2VecVocabResult(); // Return invalid result
    }
    if (auto Err = readVocabulary()) {
      emitError(std::move(Err), Ctx);
      return IR2VecVocabResult();
    }
  }
  return IR2VecVocabResult(std::move(Vocabulary));
}
//===----------------------------------------------------------------------===//
// IR2VecPrinterPass
//===----------------------------------------------------------------------===//
/// Print Vec on one line as " [ a  b ... ]", each component with two decimals.
void IR2VecPrinterPass::printVector(const Embedding &Vec) const {
  OS << " [";
  for (double Component : Vec)
    OS << " " << format("%.2f", Component) << " ";
  OS << "]\n";
}
/// For every function in M, print the function vector. Then print the
/// vectors of its basic blocks and of its instructions, limited to those the
/// embedder produced (reachable blocks and their non-debug instructions).
PreservedAnalyses IR2VecPrinterPass::run(Module &M,
                                         ModuleAnalysisManager &MAM) {
  // Bind by reference: the analysis result owns the vocabulary, and plain
  // `auto` would copy the whole result and then copy the vocabulary again.
  auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
  // The analysis has already reported why the vocabulary is invalid through
  // the LLVMContext. Do not build embedders from it: the Embedder constructor
  // reads Vocabulary.begin(), which is unsafe on an invalid (presumably empty)
  // vocabulary, and getVocabulary() only asserts validity.
  if (!VocabResult.isValid())
    return PreservedAnalyses::all();
  const auto &Vocabulary = VocabResult.getVocabulary();

  for (Function &F : M) {
    Expected<std::unique_ptr<Embedder>> EmbOrErr =
        Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
    if (auto Err = EmbOrErr.takeError()) {
      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
        OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
      });
      continue;
    }
    std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);

    OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
    OS << "Function vector: ";
    printVector(Emb->getFunctionVector());

    OS << "Basic block vectors:\n";
    const auto &BBMap = Emb->getBBVecMap();
    for (const BasicBlock &BB : F) {
      auto It = BBMap.find(&BB);
      if (It != BBMap.end()) {
        OS << "Basic block: " << BB.getName() << ":\n";
        printVector(It->second);
      }
    }

    OS << "Instruction vectors:\n";
    const auto &InstMap = Emb->getInstVecMap();
    for (const BasicBlock &BB : F) {
      for (const Instruction &I : BB) {
        auto It = InstMap.find(&I);
        if (It != InstMap.end()) {
          OS << "Instruction: ";
          I.print(OS);
          printVector(It->second);
        }
      }
    }
  }
  return PreservedAnalyses::all();
}