// blob: 0bc101631e81361008b191265bacef93551eee1a [file] [log] [blame]
//===----- VmkitAOTGC.cpp - Support for Ahead of Time Compiler GC -------===//
//
// The Vmkit project
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Type.h"
#include "llvm/CodeGen/GCs.h"
#include "llvm/CodeGen/GCStrategy.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/GCMetadataPrinter.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCContext.h"
#include "llvm/MC/MCExpr.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/Target/Mangler.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/Target/TargetInstrInfo.h"
#include "llvm/Target/TargetLoweringObjectFile.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormattedStream.h"
#include "llvm/Support/raw_ostream.h"
#include <cctype>
#include <cstdio>
using namespace llvm;
namespace {
/// VmkitAOTGC - GC strategy for AOT-generated functions. It is registered
/// under the name "vmkit" by the GCRegistry::Add object defined in this file.
class VmkitAOTGC : public GCStrategy {
public:
// Enables the CustomSafePoints and UsesMetadata flags (see the definition).
VmkitAOTGC();
// Inserts GC labels into MF and records the matching safe points in FI.
virtual bool findCustomSafePoints(GCFunctionInfo& FI, MachineFunction& MF);
};
}
// Registers the VmkitAOTGC strategy under the name "vmkit".
static GCRegistry::Add<VmkitAOTGC>
X("vmkit", "Vmkit GC for AOT-generated functions");
/// Construct the strategy and switch on the two GCStrategy options it uses:
/// metadata emission and custom safe point discovery.
VmkitAOTGC::VmkitAOTGC() {
  UsesMetadata = true;
  CustomSafePoints = true;
}
/// InsertLabel - Create a fresh temporary symbol and build a GC_LABEL
/// instruction that refers to it at position MI of MBB, tagged with DL.
///
/// \returns the newly created symbol.
static MCSymbol *InsertLabel(MachineBasicBlock &MBB,
                             MachineBasicBlock::iterator MI,
                             DebugLoc DL) {
  MachineFunction *Fn = MBB.getParent();
  const TargetInstrInfo *InstrInfo = Fn->getTarget().getInstrInfo();
  MCSymbol *Sym = Fn->getContext().CreateTempSymbol();
  BuildMI(MBB, MI, DL, InstrInfo->get(TargetOpcode::GC_LABEL)).addSym(Sym);
  return Sym;
}
/// findCustomSafePoints - Walk every instruction of MF and record safe points
/// in FI:
///  - for a call, a label is inserted right after the call and registered as
///    a GC::PostCall safe point;
///  - for any other instruction whose debug location has column 1, a label is
///    inserted in front of it and registered as a GC::Loop safe point.
/// Both kinds of safe point carry the debug location of the instruction.
///
/// Always returns false.
bool VmkitAOTGC::findCustomSafePoints(GCFunctionInfo& FI, MachineFunction &MF) {
  for (MachineFunction::iterator Block = MF.begin(), BlockEnd = MF.end();
       Block != BlockEnd; ++Block) {
    for (MachineBasicBlock::iterator Inst = Block->begin(),
                                     InstEnd = Block->end();
         Inst != InstEnd; ++Inst) {
      DebugLoc Loc = Inst->getDebugLoc();

      if (Inst->getDesc().isCall()) {
        // The label goes after the call, i.e. at the return address.
        MachineBasicBlock::iterator AfterCall = Inst;
        ++AfterCall;
        MCSymbol *PostCallLabel = InsertLabel(*Inst->getParent(), AfterCall, Loc);
        FI.addSafePoint(GC::PostCall, PostCallLabel, Loc);
        continue;
      }

      if (Loc.getCol() != 1)
        continue;

      // Column 1 marks a non-call safe point; the label precedes the
      // instruction itself.
      MCSymbol *LoopLabel = InsertLabel(*Inst->getParent(), Inst, Loc);
      FI.addSafePoint(GC::Loop, LoopLabel, Loc);
    }
  }
  return false;
}
namespace {
/// VmkitAOTGCMetadataPrinter - Metadata printer registered under the name
/// "vmkit"; it writes the frametable for the module being printed.
class VmkitAOTGCMetadataPrinter : public GCMetadataPrinter {
public:
// Emits nothing (empty definition below).
void beginAssembly(AsmPrinter &AP);
// Emits the whole frametable into the data section.
void finishAssembly(AsmPrinter &AP);
};
}
// Registers the VmkitAOTGCMetadataPrinter under the name "vmkit".
static GCMetadataPrinterRegistry::Add<VmkitAOTGCMetadataPrinter>
Y("vmkit", "Vmkit GC for AOT-generated functions");
/// beginAssembly - No-op: this printer emits nothing at the start of the
/// assembly; all output is produced by finishAssembly.
void VmkitAOTGCMetadataPrinter::beginAssembly(AsmPrinter &AP) {
}
/// isAcceptableChar - Return true if C may appear unescaped in an emitted
/// symbol name: an ASCII letter, an ASCII digit, or one of '_', '$', '@'.
static bool isAcceptableChar(char C) {
  const bool IsLower = C >= 'a' && C <= 'z';
  const bool IsUpper = C >= 'A' && C <= 'Z';
  const bool IsDigit = C >= '0' && C <= '9';
  const bool IsSpecial = C == '_' || C == '$' || C == '@';
  return IsLower || IsUpper || IsDigit || IsSpecial;
}
/// HexDigit - Map V to a hexadecimal digit character: '0'..'9' for values
/// below 10, and 'A' onwards (upper case) for 10 and above.
static char HexDigit(int V) {
  if (V < 10)
    return '0' + V;
  return 'A' + V - 10;
}
/// MangleLetter - Append the escaped form of C to OutName: an underscore, the
/// two hex digits of C (high nibble first), and a closing underscore.
static void MangleLetter(SmallVectorImpl<char> &OutName, unsigned char C) {
  const char HighNibble = HexDigit(C >> 4);
  const char LowNibble = HexDigit(C & 15);

  OutName.push_back('_');
  OutName.push_back(HighNibble);
  OutName.push_back(LowNibble);
  OutName.push_back('_');
}
/// EmitVmkitGlobal - Emit a global label at the current position of AP's
/// output stream.
///
/// The symbol name is built as "vmkit" + module identifier + "__" + Id, with
/// the character right after "vmkit" upper-cased. The result is run through
/// the target mangler, and every character rejected by isAcceptableChar is
/// then replaced by its MangleLetter escape.
///
/// \param M   module whose identifier is embedded in the symbol name.
/// \param AP  printer that supplies the mangler, context and streamer.
/// \param Id  suffix appended after "__".
static void EmitVmkitGlobal(const Module &M, AsmPrinter &AP, const char *Id) {
  const std::string &MId = M.getModuleIdentifier();

  std::string SymName;
  SymName += "vmkit";
  size_t Letter = SymName.size();
  SymName += MId;
  SymName += "__";
  SymName += Id;

  // Capitalize the first letter of the module name. toupper() requires a
  // value representable as unsigned char (or EOF); passing a plain char with
  // the high bit set is undefined behaviour where char is signed, so go
  // through unsigned char first.
  SymName[Letter] = static_cast<char>(
      toupper(static_cast<unsigned char>(SymName[Letter])));

  SmallString<128> TmpStr;
  AP.Mang->getNameWithPrefix(TmpStr, SymName);

  // Escape everything that is not acceptable in the final symbol name.
  SmallString<128> FinalStr;
  for (unsigned i = 0, e = TmpStr.size(); i != e; ++i) {
    if (!isAcceptableChar(TmpStr[i])) {
      MangleLetter(FinalStr, TmpStr[i]);
    } else {
      FinalStr.push_back(TmpStr[i]);
    }
  }

  MCSymbol *Sym = AP.OutContext.GetOrCreateSymbol(FinalStr);
  AP.OutStreamer.EmitSymbolAttribute(Sym, MCSA_Global);
  AP.OutStreamer.EmitLabel(Sym);
}
/// methodNameMatches - Check whether compiledName corresponds to the method
/// described by the character arrays name and type.
///
/// A mangled string is rebuilt from name followed by type, applying the
/// escapes below, and compared with compiledName. The escapes look like JNI
/// name mangling (presumably intentional; not verifiable from this file).
///
///   in name:  '/' -> "_", '_' -> "_1", '<' -> "_0003C", '>' -> "_0003E"
///   in type:  '(' -> "__", '/' -> "_", '_' -> "_1", '$' -> "_00024",
///             ';' -> "_2", '[' -> "_3"; processing stops at ')'
///
/// \returns true if the rebuilt string, or the rebuilt string with a trailing
/// 'S', equals compiledName.
static bool methodNameMatches(StringRef compiledName,
                              ConstantDataArray* name,
                              ConstantDataArray* type) {
  std::string mangled;

  // Method name part.
  for (uint32_t pos = 0; pos < name->getNumElements(); ++pos) {
    ConstantInt* element = cast<ConstantInt>(name->getElementAsConstant(pos));
    int16_t cur = element->getZExtValue();
    switch (cur) {
    case '/':
      mangled += '_';
      break;
    case '_':
      mangled += "_1";
      break;
    case '<':
      mangled += "_0003C";
      break;
    case '>':
      mangled += "_0003E";
      break;
    default:
      mangled += (char)cur;
      break;
    }
  }

  // Signature part: only the parameter list, up to the closing parenthesis.
  for (uint32_t pos = 0; pos < type->getNumElements(); ++pos) {
    ConstantInt* element = cast<ConstantInt>(type->getElementAsConstant(pos));
    int16_t cur = element->getZExtValue();
    if (cur == ')')
      break;
    switch (cur) {
    case '(':
      mangled += "__";
      break;
    case '/':
      mangled += '_';
      break;
    case '_':
      mangled += "_1";
      break;
    case '$':
      mangled += "_00024";
      break;
    case ';':
      mangled += "_2";
      break;
    case '[':
      mangled += "_3";
      break;
    default:
      mangled += (char)cur;
      break;
    }
  }

  uint32_t expectedSize = compiledName.size();
  if (mangled.length() > expectedSize)
    return false;
  if (mangled.compare(compiledName) == 0)
    return true;

  // Second chance: the same name with an 'S' suffix.
  mangled += 'S';
  return mangled.compare(compiledName) == 0;
}
/// FindMetadata - Locate the JavaMethod entry that describes F.
///
/// Two lookups are attempted in order:
///  1. Follow the uses of F: a constant of pointer-to-integer type that uses
///     F, then a user of that constant whose struct type is named
///     "JavaMethod". The first user of that struct is treated as the array
///     of methods and the first user of the array as the enclosing global.
///  2. If F's name starts with "JnJVM", derive the name of a
///     "...VirtualMethods" global from F's name and search its first operand
///     for an entry whose name/type strings match (see methodNameMatches).
///
/// \returns a constant GEP expression with indices {0, index-of-method} based
/// on the VirtualMethods constant, or NULL when neither lookup applies.
///
/// NOTE(review): the dyn_cast results Array and VirtualMethods in the first
/// lookup are dereferenced/used without a null check — confirm the IR shape
/// guarantees they are constants.
Constant* FindMetadata(const Function& F) {
LLVMContext& context = F.getParent()->getContext();
// Lookup 1: walk the use chain starting at the function itself.
for (Value::const_use_iterator I = F.use_begin(), E = F.use_end(); I != E; ++I) {
if (const Constant* C = dyn_cast<Constant>(*I)) {
if (PointerType* PTy = dyn_cast<PointerType>(C->getType())) {
if (isa<IntegerType>(PTy->getContainedType(0))) {
// We have found the bitcast constant that casts the method in a i8*
for (Value::const_use_iterator CI = C->use_begin(), CE = C->use_end(); CI != CE; ++CI) {
if (StructType* STy = dyn_cast<StructType>((*CI)->getType())) {
if (STy->getName().equals("JavaMethod")) {
// Method is the JavaMethod struct; Array is its first user and
// VirtualMethods is the first user of Array.
const Constant* Method = dyn_cast<Constant>(*CI);
const Constant* Array = dyn_cast<Constant>(*((*CI)->use_begin()));
Constant* VirtualMethods = dyn_cast<Constant>(const_cast<User*>((*(Array->use_begin()))));
// Find the position of Method among the operands of Array.
uint32_t index = 0;
for (; index < Array->getNumOperands(); index++) {
if (Array->getOperand(index) == Method) break;
}
assert(index != Array->getNumOperands());
Constant* GEPs[2] = { ConstantInt::get(Type::getInt32Ty(context), 0),
ConstantInt::get(Type::getInt32Ty(context), index) };
return ConstantExpr::getGetElementPtr(VirtualMethods, GEPs, 2);
}
}
}
}
}
}
}
// Lookup 2: recover the method from the function's own name.
StringRef name = F.getName();
if (name.startswith("JnJVM")) {
// Metadata for customized methods.
// methods: the part of the name before the first "__".
std::string methods = name.substr(0, name.find("__"));
// methodName: everything after the last '_' of that prefix, then cut at
// its last "__".
std::string methodName = name.substr(methods.rfind('_') + 1);
methodName = methodName.substr(0, methodName.rfind("__"));
// Drop the first 6 characters of the prefix and keep everything up to and
// including its last '_', then append "VirtualMethods" to form the name of
// the global to look up.
methods = methods.substr(6, methods.rfind('_') - 5);
methods = methods + "VirtualMethods";
Constant* VirtualMethods = cast<Constant>(F.getParent()->getNamedValue(methods));
assert(VirtualMethods);
Constant* MethodsArray = cast<Constant>(VirtualMethods->getOperand(0));
for (uint32_t index = 0; index < MethodsArray->getNumOperands(); index++) {
Constant* method = cast<Constant>(MethodsArray->getOperand(index));
// Operand 5 of the method entry leads, through two levels of operand 0,
// to a constant whose operand 1 is the character data of the name.
// Note: this local 'name' shadows the StringRef declared above.
Constant* namePtr = cast<ConstantExpr>(method->getOperand(5));
namePtr = cast<Constant>(namePtr->getOperand(0));
namePtr = cast<Constant>(namePtr->getOperand(0));
ConstantDataArray* name = cast<ConstantDataArray>(namePtr->getOperand(1));
// Operand 6 is unwrapped the same way to obtain the type string.
Constant* typePtr = cast<ConstantExpr>(method->getOperand(6));
typePtr = cast<Constant>(typePtr->getOperand(0));
typePtr = cast<Constant>(typePtr->getOperand(0));
ConstantDataArray* type = cast<ConstantDataArray>(typePtr->getOperand(1));
if (methodNameMatches(methodName, name, type)) {
Constant* GEPs[2] = { ConstantInt::get(Type::getInt32Ty(context), 0),
ConstantInt::get(Type::getInt32Ty(context), index) };
return ConstantExpr::getGetElementPtr(VirtualMethods, GEPs, 2);
}
}
// With assertions compiled out, control falls through to the NULL return.
assert(0 && "Should have found a JavaMethod");
}
return NULL;
}
/// finishAssembly - Print the frametable into the data section. The data is
/// emitted in this order (align = pointer-size alignment):
///
///   align
///   vmkit${Module}__frametable:          (global label, see EmitVmkitGlobal)
///   uint32_t NumMethodFrames;
///   align
///   struct {
///     uint32_t NumDescriptors;
///     align
///     struct {
///       void *Metadata;                  (zero when FindMetadata gives NULL)
///       void *ReturnAddress;             (safe point label; +1 when the
///                                         safe point's debug column is 1)
///       uint16_t SourceIndex;            (debug location line)
///       uint16_t FrameSize;
///       uint16_t NumLiveOffsets;
///       uint16_t LiveOffsets[NumLiveOffsets];
///       align
///     } Descriptors[NumDescriptors];
///   } MethodFrames[NumMethodFrames];
///
/// Because of the 16-bit fields, the function aborts through
/// report_fatal_error when a function has 65536 or more descriptors, a frame
/// size of 65536 or more, 65536 or more live roots at a safe point, or a root
/// stack offset of 65536 or more.
///
void VmkitAOTGCMetadataPrinter::finishAssembly(AsmPrinter &AP) {
unsigned IntPtrSize = AP.TM.getDataLayout()->getPointerSize(0);
AP.OutStreamer.SwitchSection(AP.getObjFileLowering().getDataSection());
// Alignment argument is 2 for 4-byte pointers and 3 otherwise (presumably a
// log2 value, i.e. 4 or 8 bytes — confirm against AsmPrinter::EmitAlignment).
AP.EmitAlignment(IntPtrSize == 4 ? 2 : 3);
EmitVmkitGlobal(getModule(), AP, "frametable");
// Count the functions handled by this printer; the count heads the table.
int NumMethodFrames = 0;
for (iterator I = begin(), IE = end(); I != IE; ++I) {
NumMethodFrames++;
}
AP.EmitInt32(NumMethodFrames);
AP.EmitAlignment(IntPtrSize == 4 ? 2 : 3);
for (iterator I = begin(), IE = end(); I != IE; ++I) {
GCFunctionInfo &FI = **I;
// The same metadata constant is repeated in every descriptor of FI.
Constant* Metadata = FindMetadata(FI.getFunction());
// Count the safe points of this function.
int NumDescriptors = 0;
for (GCFunctionInfo::iterator J = FI.begin(), JE = FI.end(); J != JE; ++J) {
NumDescriptors++;
}
if (NumDescriptors >= 1<<16) {
// Very rude!
report_fatal_error(" Too much descriptor for J3 AOT GC");
}
AP.EmitInt32(NumDescriptors);
AP.EmitAlignment(IntPtrSize == 4 ? 2 : 3);
uint64_t FrameSize = FI.getFrameSize();
if (FrameSize >= 1<<16) {
// Very rude!
report_fatal_error("Function '" + FI.getFunction().getName() +
"' is too large for the Vmkit AOT GC! "
"Frame size " + Twine(FrameSize) + ">= 65536.\n"
"(" + Twine(uintptr_t(&FI)) + ")");
}
AP.OutStreamer.AddComment("live roots for " +
Twine(FI.getFunction().getName()));
AP.OutStreamer.AddBlankLine();
// One descriptor per safe point.
for (GCFunctionInfo::iterator J = FI.begin(), JE = FI.end(); J != JE; ++J) {
size_t LiveCount = FI.live_size(J);
if (LiveCount >= 1<<16) {
// Very rude!
report_fatal_error("Function '" + FI.getFunction().getName() +
"' is too large for the Vmkit AOT GC! "
"Live root count "+Twine(LiveCount)+" >= 65536.");
}
DebugLoc DL = J->Loc;
// The line of the debug location is what gets stored as the source index.
// NOTE(review): it is a 32-bit value written with EmitInt16 below, with no
// range check — confirm it always fits in 16 bits.
uint32_t sourceIndex = DL.getLine();
// Metadata
if (Metadata != NULL) {
AP.EmitGlobalConstant(Metadata);
} else {
// No metadata: emit a zero of pointer size (one or two 32-bit zeros).
AP.EmitInt32(0);
if (IntPtrSize == 8) {
AP.EmitInt32(0);
}
}
// Return address
const MCExpr* address = MCSymbolRefExpr::Create(J->Label, AP.OutStreamer.getContext());
// Column 1 is the marker findCustomSafePoints uses for GC::Loop safe points;
// for those the emitted address is the label plus one.
if (DL.getCol() == 1) {
const MCExpr* one = MCConstantExpr::Create(1, AP.OutStreamer.getContext());
address = MCBinaryExpr::CreateAdd(address, one, AP.OutStreamer.getContext());
}
AP.OutStreamer.EmitValue(address, IntPtrSize, 0);
AP.EmitInt16(sourceIndex);
AP.EmitInt16(FrameSize);
AP.EmitInt16(LiveCount);
// Stack offsets of the live roots at this safe point.
for (GCFunctionInfo::live_iterator K = FI.live_begin(J),
KE = FI.live_end(J); K != KE; ++K) {
if (K->StackOffset >= 1<<16) {
// Very rude!
// NOTE(review): the message below still mentions the ocaml GC.
report_fatal_error(
"GC root stack offset is outside of fixed stack frame and out "
"of range for ocaml GC!");
}
AP.EmitInt16(K->StackOffset);
}
// Re-align so the next descriptor starts on a pointer-size boundary.
AP.EmitAlignment(IntPtrSize == 4 ? 2 : 3);
}
}
}