blob: 706e54f064aa6ed3497bc90f33998bd7a6803a0c [file] [log] [blame]
//===- LoweringPrepare.cpp - preparation work for LLVM lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "clang/AST/ASTContext.h"
#include "clang/Basic/Module.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/MissingFeatures.h"
#include "llvm/Support/Path.h"
#include <memory>
using namespace mlir;
using namespace cir;
/// Derive a symbol-friendly file name from the module's symbol name.
///
/// Takes the last path component of the module's `sym_name`, falling back to
/// "<null>" when there is no name or it is empty. Every character outside
/// [a-zA-Z0-9._] (the characters of a C preprocessing number) is then
/// replaced by '_', so the fallback itself becomes "_null_".
static SmallString<128> getTransformedFileName(mlir::ModuleOp mlirModule) {
  SmallString<128> fileName;
  if (auto moduleName = mlirModule.getSymName())
    fileName = llvm::sys::path::filename(*moduleName);
  if (fileName.empty())
    fileName = "<null>";

  for (char &ch : fileName)
    if (!clang::isPreprocessingNumberBody(ch))
      ch = '_';
  return fileName;
}
/// Return the cir::FuncOp that `callOp` calls directly.
///
/// Returns a null FuncOp when the callee is not a symbol reference, or when
/// that symbol does not resolve to a cir::FuncOp.
static cir::FuncOp getCalledFunction(cir::CallOp callOp) {
  auto calleeSym = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
      callOp.getCallableForCallee());
  if (!calleeSym)
    return {};

  mlir::Operation *callee =
      mlir::SymbolTable::lookupNearestSymbolFrom(callOp, calleeSym);
  return dyn_cast_or_null<cir::FuncOp>(callee);
}
namespace {
/// Pass that rewrites selected high-level CIR constructs into simpler CIR
/// before LLVM lowering. It handles:
///  - complex casts and arithmetic;
///  - complex unary operations;
///  - global variables with ctor/dtor regions;
///  - array ctor/dtor operations.
/// It also emits the module's global initialization function and the global
/// ctor list.
struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
  LoweringPreparePass() = default;
  void runOnOperation() override;

  /// Dispatch `op` to the matching lower* method below.
  void runOnOp(mlir::Operation *op);
  void lowerCastOp(cir::CastOp op);
  void lowerComplexDivOp(cir::ComplexDivOp op);
  void lowerComplexMulOp(cir::ComplexMulOp op);
  void lowerUnaryOp(cir::UnaryOp op);
  void lowerGlobalOp(cir::GlobalOp op);
  void lowerArrayDtor(cir::ArrayDtor op);
  void lowerArrayCtor(cir::ArrayCtor op);

  /// Build the function that initializes the specified global
  cir::FuncOp buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op);

  /// Build a module init function that calls all the dynamic initializers.
  void buildCXXGlobalInitFunc();

  /// Materialize global ctor/dtor list
  void buildGlobalCtorDtorList();

  /// Return the function `name` if the module already has one. Otherwise,
  /// create it at the builder's insertion point with the given linkage and
  /// private symbol visibility.
  cir::FuncOp buildRuntimeFunction(
      mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
      cir::FuncType type,
      cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);

  /// Return the global `name` if the module already has one. Otherwise,
  /// create it at the builder's insertion point with the given linkage,
  /// visibility, and private symbol visibility.
  cir::GlobalOp buildRuntimeVariable(
      mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
      mlir::Type type,
      cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage,
      cir::VisibilityKind visibility = cir::VisibilityKind::Default);

  ///
  /// AST related
  /// -----------

  /// AST context used for target and language queries. It stays null unless
  /// setASTContext() is called, e.g. when the pass is created through
  /// createLoweringPreparePass() without arguments. Previously this member
  /// was left uninitialized in that case, and reading it was undefined
  /// behavior.
  clang::ASTContext *astCtx = nullptr;

  /// Tracks current module.
  mlir::ModuleOp mlirModule;

  /// Tracks existing dynamic initializers.
  llvm::StringMap<uint32_t> dynamicInitializerNames;
  llvm::SmallVector<cir::FuncOp> dynamicInitializers;

  /// List of ctors and their priorities to be called before main()
  llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalCtorList;

  void setASTContext(clang::ASTContext *c) { astCtx = c; }
};
} // namespace
/// Return the module's existing cir::GlobalOp named `name`, or create one.
///
/// A new global is created at the builder's insertion point with type
/// `type`, the requested linkage and visibility attributes, and private MLIR
/// symbol visibility. If the symbol exists but is not a cir::GlobalOp, a new
/// global is still created.
cir::GlobalOp LoweringPreparePass::buildRuntimeVariable(
    mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
    mlir::Type type, cir::GlobalLinkageKind linkage,
    cir::VisibilityKind visibility) {
  mlir::MLIRContext *moduleCtx = mlirModule->getContext();
  mlir::Operation *existing = mlir::SymbolTable::lookupNearestSymbolFrom(
      mlirModule, mlir::StringAttr::get(moduleCtx, name));
  if (auto found = dyn_cast_or_null<cir::GlobalOp>(existing))
    return found;

  auto global = cir::GlobalOp::create(builder, loc, name, type);
  global.setLinkageAttr(
      cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
  mlir::SymbolTable::setSymbolVisibility(
      global, mlir::SymbolTable::Visibility::Private);
  global.setGlobalVisibilityAttr(
      cir::VisibilityAttr::get(builder.getContext(), visibility));
  return global;
}
/// Return the module's existing cir::FuncOp named `name`, or create one.
///
/// A new function is created at the builder's insertion point with type
/// `type`, the given linkage, and private MLIR symbol visibility. If the
/// symbol exists but is not a cir::FuncOp, a new function is still created.
cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
    mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
    cir::FuncType type, cir::GlobalLinkageKind linkage) {
  mlir::Operation *existing = mlir::SymbolTable::lookupNearestSymbolFrom(
      mlirModule, mlir::StringAttr::get(mlirModule->getContext(), name));
  if (auto found = dyn_cast_or_null<cir::FuncOp>(existing))
    return found;

  auto fn = cir::FuncOp::create(builder, loc, name, type);
  fn.setLinkageAttr(
      cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
  mlir::SymbolTable::setSymbolVisibility(fn,
                                         mlir::SymbolTable::Visibility::Private);
  assert(!cir::MissingFeatures::opFuncExtraAttrs());
  return fn;
}
/// Lower a scalar-to-complex cast. The scalar becomes the real part, and the
/// imaginary part is the null value of the scalar's type. The new operations
/// are inserted before `op`.
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
                                            cir::CastOp op) {
  cir::CIRBaseBuilderTy builder(ctx);
  builder.setInsertionPoint(op);
  const mlir::Location loc = op.getLoc();

  mlir::Value realPart = op.getSrc();
  mlir::Value imagPart = builder.getNullValue(realPart.getType(), loc);
  return builder.createComplexCreate(loc, realPart, imagPart);
}
/// Lower a complex-to-scalar cast, inserting the new operations before `op`.
///
/// A cast to a non-bool type keeps only the real part. A cast to bool yields
/// `(bool)real || (bool)imag`, where each part is converted with
/// `elemToBoolKind`.
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx,
                                            cir::CastOp op,
                                            cir::CastKind elemToBoolKind) {
  cir::CIRBaseBuilderTy builder(ctx);
  builder.setInsertionPoint(op);
  const mlir::Location loc = op.getLoc();
  mlir::Value complexVal = op.getSrc();

  mlir::Value real = builder.createComplexReal(loc, complexVal);
  const bool castsToBool = mlir::isa<cir::BoolType>(op.getType());
  if (!castsToBool)
    return real;

  // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b
  mlir::Value imag = builder.createComplexImag(loc, complexVal);
  cir::BoolType boolTy = builder.getBoolTy();
  mlir::Value realAsBool =
      builder.createCast(loc, elemToBoolKind, real, boolTy);
  mlir::Value imagAsBool =
      builder.createCast(loc, elemToBoolKind, imag, boolTy);
  return builder.createLogicalOr(loc, realAsBool, imagAsBool);
}
/// Lower a complex-to-complex cast. The real and imaginary parts are each
/// cast to the destination element type with `scalarCastKind` and then
/// recombined. The new operations are inserted before `op`.
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx,
                                             cir::CastOp op,
                                             cir::CastKind scalarCastKind) {
  CIRBaseBuilderTy builder(ctx);
  builder.setInsertionPoint(op);
  const mlir::Location loc = op.getLoc();
  mlir::Value srcComplex = op.getSrc();
  mlir::Type dstElemTy =
      mlir::cast<cir::ComplexType>(op.getType()).getElementType();

  mlir::Value real = builder.createComplexReal(loc, srcComplex);
  mlir::Value imag = builder.createComplexImag(loc, srcComplex);
  mlir::Value castReal =
      builder.createCast(loc, scalarCastKind, real, dstElemTy);
  mlir::Value castImag =
      builder.createCast(loc, scalarCastKind, imag, dstElemTy);
  return builder.createComplexCreate(loc, castReal, castImag);
}
/// Expand casts that involve complex values into operations on their real and
/// imaginary parts, then replace and erase `op`. Casts of any other kind are
/// left untouched.
void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
  mlir::MLIRContext &ctx = getContext();
  mlir::Value replacement;

  switch (op.getKind()) {
  // scalar -> complex
  case cir::CastKind::float_to_complex:
  case cir::CastKind::int_to_complex:
    replacement = lowerScalarToComplexCast(ctx, op);
    break;

  // complex -> scalar
  case cir::CastKind::float_complex_to_real:
  case cir::CastKind::int_complex_to_real:
    replacement = lowerComplexToScalarCast(ctx, op, op.getKind());
    break;
  case cir::CastKind::float_complex_to_bool:
    replacement =
        lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool);
    break;
  case cir::CastKind::int_complex_to_bool:
    replacement = lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool);
    break;

  // complex -> complex
  case cir::CastKind::float_complex:
    replacement = lowerComplexToComplexCast(ctx, op, cir::CastKind::floating);
    break;
  case cir::CastKind::float_complex_to_int_complex:
    replacement =
        lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int);
    break;
  case cir::CastKind::int_complex:
    replacement = lowerComplexToComplexCast(ctx, op, cir::CastKind::integral);
    break;
  case cir::CastKind::int_complex_to_float_complex:
    replacement =
        lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float);
    break;

  default:
    break;
  }

  if (!replacement)
    return;
  op.replaceAllUsesWith(replacement);
  op.erase();
}
/// Emit a call to the runtime library routine that performs a complex binary
/// operation on values of type `ty`, and return the call's result.
///
/// `libFuncNameGetter` maps the element type's float semantics to the routine
/// name. If the routine is not yet declared, it is declared at the start of
/// the module with signature `ty(elt, elt, elt, elt)`. The call passes
/// (lhsReal, lhsImag, rhsReal, rhsImag) at the builder's current position.
static mlir::Value buildComplexBinOpLibCall(
    LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
    llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
    mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
    mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
  auto fpElemTy = mlir::cast<cir::FPTypeInterface>(ty.getElementType());
  llvm::APFloat::Semantics semantics =
      llvm::APFloat::SemanticsToEnum(fpElemTy.getFloatSemantics());
  llvm::StringRef calleeName = libFuncNameGetter(semantics);

  llvm::SmallVector<mlir::Type, 4> paramTypes(4, fpElemTy);
  auto calleeTy = cir::FuncType::get(paramTypes, ty);

  cir::FuncOp callee = [&] {
    // Emit the declaration at the top of the module, then restore the
    // caller's insertion point when the guard goes out of scope.
    mlir::OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(pass.mlirModule.getBody());
    return pass.buildRuntimeFunction(builder, calleeName, loc, calleeTy);
  }();

  return builder
      .createCallOp(loc, callee, {lhsReal, lhsImag, rhsReal, rhsImag})
      .getResult();
}
/// Map a floating-point semantics to the name of the runtime routine that
/// divides two complex numbers with that element format.
static llvm::StringRef
getComplexDivLibCallName(llvm::APFloat::Semantics semantics) {
  switch (semantics) {
  case llvm::APFloat::S_IEEEhalf:
    return "__divhc3";
  case llvm::APFloat::S_IEEEsingle:
    return "__divsc3";
  case llvm::APFloat::S_IEEEdouble:
    return "__divdc3";
  case llvm::APFloat::S_x87DoubleExtended:
    return "__divxc3";
  // Both 128-bit formats map to the same "tc" routine name.
  case llvm::APFloat::S_PPCDoubleDouble:
  case llvm::APFloat::S_IEEEquad:
    return "__divtc3";
  default:
    llvm_unreachable("unsupported floating point type");
  }
}
/// Emit the textbook complex quotient
///   (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i
/// where a/b are the lhs parts and c/d the rhs parts. Overflow, underflow and
/// NaN/Inf operands get no special handling.
static mlir::Value
buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
                         mlir::Value lhsReal, mlir::Value lhsImag,
                         mlir::Value rhsReal, mlir::Value rhsImag) {
  auto binop = [&](mlir::Value x, cir::BinOpKind kind,
                   mlir::Value y) -> mlir::Value {
    return builder.createBinop(loc, x, kind, y);
  };

  // One operation per statement keeps the emission order fixed.
  mlir::Value ac = binop(lhsReal, cir::BinOpKind::Mul, rhsReal);
  mlir::Value bd = binop(lhsImag, cir::BinOpKind::Mul, rhsImag);
  mlir::Value cc = binop(rhsReal, cir::BinOpKind::Mul, rhsReal);
  mlir::Value dd = binop(rhsImag, cir::BinOpKind::Mul, rhsImag);
  mlir::Value realNumerator = binop(ac, cir::BinOpKind::Add, bd);
  mlir::Value denominator = binop(cc, cir::BinOpKind::Add, dd);
  mlir::Value quotientReal =
      binop(realNumerator, cir::BinOpKind::Div, denominator);

  mlir::Value bc = binop(lhsImag, cir::BinOpKind::Mul, rhsReal);
  mlir::Value ad = binop(lhsReal, cir::BinOpKind::Mul, rhsImag);
  mlir::Value imagNumerator = binop(bc, cir::BinOpKind::Sub, ad);
  mlir::Value quotientImag =
      binop(imagNumerator, cir::BinOpKind::Div, denominator);

  return builder.createComplexCreate(loc, quotientReal, quotientImag);
}
/// Emit complex division using Smith's algorithm (range reduction).
///
/// The algebraic formula forms the denominator `c*c + d*d`. This version
/// avoids that: it scales by the ratio of the divisor's smaller-magnitude
/// part to its larger-magnitude part, which reduces the risk of intermediate
/// overflow. The branch is chosen at runtime by a cir.ternary on
/// `fabs(c) >= fabs(d)`. Each branch yields the complex result, which is
/// returned as the ternary's result.
static mlir::Value
buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
                              mlir::Value lhsReal, mlir::Value lhsImag,
                              mlir::Value rhsReal, mlir::Value rhsImag) {
  // Implements Smith's algorithm for complex division.
  // SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962).

  // Let:
  //   - lhs := a+bi
  //   - rhs := c+di
  //   - result := lhs / rhs = e+fi
  //
  // The algorithm pseudocode looks like follows:
  //   if fabs(c) >= fabs(d):
  //     r := d / c
  //     tmp := c + r*d
  //     e = (a + b*r) / tmp
  //     f = (b - a*r) / tmp
  //   else:
  //     r := c / d
  //     tmp := d + r*c
  //     e = (a*r + b) / tmp
  //     f = (b*r - a) / tmp

  mlir::Value &a = lhsReal;
  mlir::Value &b = lhsImag;
  mlir::Value &c = rhsReal;
  mlir::Value &d = rhsImag;

  // Branch taken when |c| >= |d|. The branch builders ignore their OpBuilder
  // parameter and emit through the captured `builder`.
  auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
    mlir::Value r = builder.createBinop(loc, d, cir::BinOpKind::Div,
                                        c); // r := d / c
    mlir::Value rd = builder.createBinop(loc, r, cir::BinOpKind::Mul, d); // r*d
    mlir::Value tmp = builder.createBinop(loc, c, cir::BinOpKind::Add,
                                          rd); // tmp := c + r*d

    mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r
    mlir::Value abr =
        builder.createBinop(loc, a, cir::BinOpKind::Add, br); // a + b*r
    mlir::Value e = builder.createBinop(loc, abr, cir::BinOpKind::Div, tmp);

    mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r
    mlir::Value bar =
        builder.createBinop(loc, b, cir::BinOpKind::Sub, ar); // b - a*r
    mlir::Value f = builder.createBinop(loc, bar, cir::BinOpKind::Div, tmp);

    mlir::Value result = builder.createComplexCreate(loc, e, f);
    builder.createYield(loc, result);
  };

  // Branch taken when |c| < |d|.
  auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
    mlir::Value r = builder.createBinop(loc, c, cir::BinOpKind::Div,
                                        d); // r := c / d
    mlir::Value rc = builder.createBinop(loc, r, cir::BinOpKind::Mul, c); // r*c
    mlir::Value tmp = builder.createBinop(loc, d, cir::BinOpKind::Add,
                                          rc); // tmp := d + r*c

    mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r
    mlir::Value arb =
        builder.createBinop(loc, ar, cir::BinOpKind::Add, b); // a*r + b
    mlir::Value e = builder.createBinop(loc, arb, cir::BinOpKind::Div, tmp);

    mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r
    mlir::Value bra =
        builder.createBinop(loc, br, cir::BinOpKind::Sub, a); // b*r - a
    mlir::Value f = builder.createBinop(loc, bra, cir::BinOpKind::Div, tmp);

    mlir::Value result = builder.createComplexCreate(loc, e, f);
    builder.createYield(loc, result);
  };

  // Select the branch at runtime: fabs(c) >= fabs(d).
  auto cFabs = builder.create<cir::FAbsOp>(loc, c);
  auto dFabs = builder.create<cir::FAbsOp>(loc, d);
  cir::CmpOp cmpResult =
      builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs);
  auto ternary = builder.create<cir::TernaryOp>(
      loc, cmpResult, trueBranchBuilder, falseBranchBuilder);

  return ternary.getResult();
}
/// Return the wider element type to use when complex arithmetic on
/// `elementType` is evaluated in higher precision. Return a null type when no
/// usable wider type exists.
///
/// The candidate wider type is:
///  - f16 -> float;
///  - float and bf16 -> double;
///  - double -> long double;
///  - any other type -> itself.
/// The candidate is accepted only if its exponent range can hold the product
/// of two maximal values of the original type (see the check below).
///
/// `context` is used to create the candidate type, and `cc` supplies the
/// target's float formats. `builder` is currently unused.
static mlir::Type higherPrecisionElementTypeForComplexArithmetic(
    mlir::MLIRContext &context, clang::ASTContext &cc,
    CIRBaseBuilderTy &builder, mlir::Type elementType) {
  auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
    if (mlir::isa<cir::FP16Type>(type))
      return cir::SingleType::get(&context);

    if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
      return cir::DoubleType::get(&context);

    // NOTE(review): the long double is created with `type` (double) as its
    // underlying type. getFloatTypeSemantics below, however, reports the
    // target's long double format for it. Confirm this is intended.
    if (mlir::isa<cir::DoubleType>(type))
      return cir::LongDoubleType::get(&context, type);

    return type;
  };

  auto getFloatTypeSemantics =
      [&cc](mlir::Type type) -> const llvm::fltSemantics & {
    const clang::TargetInfo &info = cc.getTargetInfo();
    if (mlir::isa<cir::FP16Type>(type))
      return info.getHalfFormat();
    if (mlir::isa<cir::BF16Type>(type))
      return info.getBFloat16Format();
    if (mlir::isa<cir::SingleType>(type))
      return info.getFloatFormat();
    if (mlir::isa<cir::DoubleType>(type))
      return info.getDoubleFormat();
    if (mlir::isa<cir::LongDoubleType>(type)) {
      if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
        llvm_unreachable("NYI Float type semantics with OpenMP");
      return info.getLongDoubleFormat();
    }
    if (mlir::isa<cir::FP128Type>(type)) {
      if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
        llvm_unreachable("NYI Float type semantics with OpenMP");
      return info.getFloat128Format();
    }
    // llvm_unreachable (not assert(false)) so release builds never fall off
    // the end of a function that must return a reference.
    llvm_unreachable("Unsupported float type semantics");
  };

  const mlir::Type higherElementType = getHigherPrecisionFPType(elementType);
  const llvm::fltSemantics &elementTypeSemantics =
      getFloatTypeSemantics(elementType);
  const llvm::fltSemantics &higherElementTypeSemantics =
      getFloatTypeSemantics(higherElementType);

  // Check that the promoted type can handle the intermediate values without
  // overflowing. This can be interpreted as:
  // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <=
  // LargerType.LargestFiniteVal.
  // In terms of exponent it gives this formula:
  // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal
  // doubles the exponent of SmallerType.LargestFiniteVal)
  if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
      llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
    return higherElementType;
  }

  // The intermediate values can't be represented in the promoted type
  // without overflowing.
  return {};
}
/// Expand the complex division `(lhsReal + lhsImag i) / (rhsReal + rhsImag i)`
/// for `op`.
///
/// Non-floating-point element types use the algebraic formula. For
/// floating-point element types, the op's range kind chooses the strategy:
///  - Improved: Smith's range-reduction algorithm.
///  - Full: a call to the runtime library routine.
///  - Promoted: the algebraic formula evaluated in a wider element type, with
///    the result cast back. If no usable wider type exists, Smith's algorithm
///    is used instead.
///  - any other range: the algebraic formula.
static mlir::Value
lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
                mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
                mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
                mlir::MLIRContext &mlirCx, clang::ASTContext &cc) {
  cir::ComplexType complexTy = op.getType();
  mlir::Type elemTy = complexTy.getElementType();

  if (!mlir::isa<cir::FPTypeInterface>(elemTy))
    return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
                                    rhsImag);

  const cir::ComplexRangeKind range = op.getRange();
  if (range == cir::ComplexRangeKind::Improved)
    return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
                                         rhsReal, rhsImag);
  if (range == cir::ComplexRangeKind::Full)
    return buildComplexBinOpLibCall(pass, builder, &getComplexDivLibCallName,
                                    loc, complexTy, lhsReal, lhsImag, rhsReal,
                                    rhsImag);
  if (range != cir::ComplexRangeKind::Promoted)
    return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
                                    rhsImag);

  // Promoted range: compute in a wider type when one is usable.
  mlir::Type wideTy = higherPrecisionElementTypeForComplexArithmetic(
      mlirCx, cc, builder, elemTy);
  if (!wideTy)
    return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
                                         rhsReal, rhsImag);

  auto convert = [&](mlir::Value v, mlir::Type to) -> mlir::Value {
    return builder.createCast(cir::CastKind::floating, v, to);
  };

  // One cast per statement keeps the emission order fixed.
  mlir::Value wideLhsReal = convert(lhsReal, wideTy);
  mlir::Value wideLhsImag = convert(lhsImag, wideTy);
  mlir::Value wideRhsReal = convert(rhsReal, wideTy);
  mlir::Value wideRhsImag = convert(rhsImag, wideTy);

  mlir::Value wideQuotient = buildAlgebraicComplexDiv(
      builder, loc, wideLhsReal, wideLhsImag, wideRhsReal, wideRhsImag);
  mlir::Value wideReal = builder.createComplexReal(loc, wideQuotient);
  mlir::Value wideImag = builder.createComplexImag(loc, wideQuotient);

  mlir::Value narrowReal = convert(wideReal, elemTy);
  mlir::Value narrowImag = convert(wideImag, elemTy);
  return builder.createComplexCreate(loc, narrowReal, narrowImag);
}
/// Replace a cir::ComplexDivOp with its expanded form (see lowerComplexDiv).
/// The expansion is emitted immediately after the original operation, which
/// is then erased.
void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
  mlir::MLIRContext &mlirCtx = getContext();
  cir::CIRBaseBuilderTy builder(mlirCtx);
  builder.setInsertionPointAfter(op);
  const mlir::Location loc = op.getLoc();

  mlir::TypedValue<cir::ComplexType> dividend = op.getLhs();
  mlir::TypedValue<cir::ComplexType> divisor = op.getRhs();
  mlir::Value a = builder.createComplexReal(loc, dividend);
  mlir::Value b = builder.createComplexImag(loc, dividend);
  mlir::Value c = builder.createComplexReal(loc, divisor);
  mlir::Value d = builder.createComplexImag(loc, divisor);

  mlir::Value quotient =
      lowerComplexDiv(*this, builder, loc, op, a, b, c, d, mlirCtx, *astCtx);
  op.replaceAllUsesWith(quotient);
  op.erase();
}
/// Map a floating-point semantics to the name of the runtime routine that
/// multiplies two complex numbers with that element format.
static llvm::StringRef
getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
  switch (semantics) {
  case llvm::APFloat::S_IEEEhalf:
    return "__mulhc3";
  case llvm::APFloat::S_IEEEsingle:
    return "__mulsc3";
  case llvm::APFloat::S_IEEEdouble:
    return "__muldc3";
  case llvm::APFloat::S_x87DoubleExtended:
    return "__mulxc3";
  // Both 128-bit formats map to the same "tc" routine name.
  case llvm::APFloat::S_PPCDoubleDouble:
  case llvm::APFloat::S_IEEEquad:
    return "__multc3";
  default:
    llvm_unreachable("unsupported floating point type");
  }
}
/// Expand the complex multiplication `(a+bi) * (c+di)` for `op`.
///
/// The algebraic product `(ac-bd) + (ad+bc)i` is always emitted. It is the
/// final result for integer element types and for the Basic, Improved and
/// Promoted ranges. For other ranges, a cir.ternary checks whether both parts
/// of the algebraic result are NaN. If they are, the result comes from a
/// runtime library call instead.
static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
                                   CIRBaseBuilderTy &builder,
                                   mlir::Location loc, cir::ComplexMulOp op,
                                   mlir::Value lhsReal, mlir::Value lhsImag,
                                   mlir::Value rhsReal, mlir::Value rhsImag) {
  // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
  mlir::Value ac =
      builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal);
  mlir::Value bd =
      builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag);
  mlir::Value ad =
      builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag);
  mlir::Value bc =
      builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal);
  mlir::Value productReal = builder.createBinop(loc, ac, cir::BinOpKind::Sub, bd);
  mlir::Value productImag = builder.createBinop(loc, ad, cir::BinOpKind::Add, bc);
  mlir::Value algebraicResult =
      builder.createComplexCreate(loc, productReal, productImag);

  cir::ComplexType complexTy = op.getType();
  const cir::ComplexRangeKind rangeKind = op.getRange();
  const bool isIntegral = mlir::isa<cir::IntType>(complexTy.getElementType());
  const bool needsNaNFallback = !isIntegral &&
                                rangeKind != cir::ComplexRangeKind::Basic &&
                                rangeKind != cir::ComplexRangeKind::Improved &&
                                rangeKind != cir::ComplexRangeKind::Promoted;
  if (!needsNaNFallback)
    return algebraicResult;

  assert(!cir::MissingFeatures::fastMathFlags());

  // Fall back to the library routine only when both parts are NaN.
  mlir::Value realIsNaN = builder.createIsNaN(loc, productReal);
  mlir::Value imagIsNaN = builder.createIsNaN(loc, productImag);
  mlir::Value bothNaN = builder.createLogicalAnd(loc, realIsNaN, imagIsNaN);

  auto emitLibCall = [&](mlir::OpBuilder &, mlir::Location) {
    mlir::Value libCallResult = buildComplexBinOpLibCall(
        pass, builder, &getComplexMulLibCallName, loc, complexTy, lhsReal,
        lhsImag, rhsReal, rhsImag);
    builder.createYield(loc, libCallResult);
  };
  auto yieldAlgebraic = [&](mlir::OpBuilder &, mlir::Location) {
    builder.createYield(loc, algebraicResult);
  };

  auto select = builder.create<cir::TernaryOp>(loc, bothNaN, emitLibCall,
                                               yieldAlgebraic);
  return select.getResult();
}
/// Replace a cir::ComplexMulOp with its expanded form (see lowerComplexMul).
/// The expansion is emitted immediately after the original operation, which
/// is then erased.
void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
  cir::CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op);
  const mlir::Location loc = op.getLoc();

  mlir::TypedValue<cir::ComplexType> multiplicand = op.getLhs();
  mlir::TypedValue<cir::ComplexType> multiplier = op.getRhs();
  mlir::Value a = builder.createComplexReal(loc, multiplicand);
  mlir::Value b = builder.createComplexImag(loc, multiplicand);
  mlir::Value c = builder.createComplexReal(loc, multiplier);
  mlir::Value d = builder.createComplexImag(loc, multiplier);

  mlir::Value product = lowerComplexMul(*this, builder, loc, op, a, b, c, d);
  op.replaceAllUsesWith(product);
  op.erase();
}
/// Lower a unary operation on a complex value into operations on its real and
/// imaginary parts, then replace and erase `op`. Unary operations on
/// non-complex types are left untouched.
///  - Inc/Dec apply to the real part only.
///  - Plus/Minus apply to both parts.
///  - Not keeps the real part and negates the imaginary part (conjugate).
void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) {
  if (!mlir::isa<cir::ComplexType>(op.getType()))
    return;

  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op);
  const mlir::Location loc = op.getLoc();
  const cir::UnaryOpKind kind = op.getKind();

  mlir::Value input = op.getInput();
  mlir::Value inReal = builder.createComplexReal(loc, input);
  mlir::Value inImag = builder.createComplexImag(loc, input);

  mlir::Value outReal;
  mlir::Value outImag;
  switch (kind) {
  case cir::UnaryOpKind::Inc:
  case cir::UnaryOpKind::Dec:
    outReal = builder.createUnaryOp(loc, kind, inReal);
    outImag = inImag;
    break;

  case cir::UnaryOpKind::Plus:
  case cir::UnaryOpKind::Minus:
    outReal = builder.createUnaryOp(loc, kind, inReal);
    outImag = builder.createUnaryOp(loc, kind, inImag);
    break;

  case cir::UnaryOpKind::Not:
    outReal = inReal;
    outImag = builder.createUnaryOp(loc, cir::UnaryOpKind::Minus, inImag);
    break;
  }

  mlir::Value lowered = builder.createComplexCreate(loc, outReal, outImag);
  op.replaceAllUsesWith(lowered);
  op.erase();
}
/// Build an internal-linkage `__cxx_global_var_init[.N]` function that
/// performs the dynamic initialization of global `op`.
///
/// The function is built as follows:
///  - All operations of `op`'s ctor region except its terminator are moved
///    to the start of the new function's entry block.
///  - If `op` has a dtor region, the last cir.call in it (the destructor call)
///    is replaced by a call to `__cxa_atexit`. The arguments are the
///    destructor's address, the destructor call's first argument (both
///    bitcast), and `&__dso_handle`. The remaining dtor-region operations,
///    except the terminator, are then appended to the entry block.
///  - The function is terminated with a cir.return carrying the location of
///    the region's cir.yield terminator.
///
/// This function does not clear `op`'s ctor/dtor regions.
cir::FuncOp
LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) {
  // TODO(cir): Store this in the GlobalOp.
  // This should come from the MangleContext, but for now I'm hardcoding it.
  SmallString<256> fnName("__cxx_global_var_init");
  // Get a unique name. The first initializer keeps the bare name; later ones
  // get a ".<count>" suffix.
  uint32_t cnt = dynamicInitializerNames[fnName]++;
  if (cnt)
    fnName += "." + llvm::Twine(cnt).str();

  // Create a variable initialization function.
  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op);
  cir::VoidType voidTy = builder.getVoidTy();
  auto fnType = cir::FuncType::get({}, voidTy);
  FuncOp f = buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
                                  cir::GlobalLinkageKind::InternalLinkage);

  // Move over the initialization code of the ctor region. The terminator
  // stays behind and is inspected at the end of this function.
  mlir::Block *entryBB = f.addEntryBlock();
  if (!op.getCtorRegion().empty()) {
    mlir::Block &block = op.getCtorRegion().front();
    entryBB->getOperations().splice(entryBB->begin(), block.getOperations(),
                                    block.begin(), std::prev(block.end()));
  }

  // Register the destructor call with __cxa_atexit
  mlir::Region &dtorRegion = op.getDtorRegion();
  if (!dtorRegion.empty()) {
    assert(!cir::MissingFeatures::astVarDeclInterface());
    assert(!cir::MissingFeatures::opGlobalThreadLocal());

    // Create a variable that binds the atexit to this shared object.
    builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
    cir::GlobalOp handle = buildRuntimeVariable(
        builder, "__dso_handle", op.getLoc(), builder.getI8Type(),
        cir::GlobalLinkageKind::ExternalLinkage, cir::VisibilityKind::Hidden);

    // Look for the destructor call in dtorBlock: iterate the block's cir.call
    // ops in reverse and take the first one, i.e. the last call in the block.
    // Note that the loop variable shadows the `op` parameter.
    mlir::Block &dtorBlock = dtorRegion.front();
    cir::CallOp dtorCall;
    for (auto op : reverse(dtorBlock.getOps<cir::CallOp>())) {
      dtorCall = op;
      break;
    }
    assert(dtorCall && "Expected a dtor call");
    cir::FuncOp dtorFunc = getCalledFunction(dtorCall);
    assert(dtorFunc && "Expected a dtor call");

    // Create a runtime helper function:
    //    extern "C" int __cxa_atexit(void (*f)(void *), void *p, void *d);
    // NOTE(review): the declaration below uses a void return type rather than
    // int. The result is never used here.
    auto voidPtrTy = cir::PointerType::get(voidTy);
    auto voidFnTy = cir::FuncType::get({voidPtrTy}, voidTy);
    auto voidFnPtrTy = cir::PointerType::get(voidFnTy);
    auto handlePtrTy = cir::PointerType::get(handle.getSymType());
    auto fnAtExitType =
        cir::FuncType::get({voidFnPtrTy, voidPtrTy, handlePtrTy}, voidTy);
    const char *nameAtExit = "__cxa_atexit";
    cir::FuncOp fnAtExit =
        buildRuntimeFunction(builder, nameAtExit, op.getLoc(), fnAtExitType);

    // Replace the dtor call with a call to __cxa_atexit(&dtor, &var,
    // &__dso_handle). The new ops are inserted right after the dtor call in
    // dtorBlock and are moved into the init function by the splice below.
    builder.setInsertionPointAfter(dtorCall);
    mlir::Value args[3];
    auto dtorPtrTy = cir::PointerType::get(dtorFunc.getFunctionType());
    // Address of the destructor, bitcast to the `void (*)(void *)` pointer
    // type used in the __cxa_atexit declaration.
    args[0] = cir::GetGlobalOp::create(builder, dtorCall.getLoc(), dtorPtrTy,
                                       dtorFunc.getSymName());
    args[0] = cir::CastOp::create(builder, dtorCall.getLoc(), voidFnPtrTy,
                                  cir::CastKind::bitcast, args[0]);
    // The destructor call's first argument, bitcast to `void *`.
    args[1] =
        cir::CastOp::create(builder, dtorCall.getLoc(), voidPtrTy,
                            cir::CastKind::bitcast, dtorCall.getArgOperand(0));
    args[2] = cir::GetGlobalOp::create(builder, handle.getLoc(), handlePtrTy,
                                       handle.getSymName());
    builder.createCallOp(dtorCall.getLoc(), fnAtExit, args);
    dtorCall->erase();
    // Append the rest of the dtor region, minus its terminator, after the
    // ctor code.
    entryBB->getOperations().splice(entryBB->end(), dtorBlock.getOperations(),
                                    dtorBlock.begin(),
                                    std::prev(dtorBlock.end()));
  }

  // Terminate the function with cir.return, reusing the location of the
  // region's cir.yield terminator. The yield itself is left in its region.
  builder.setInsertionPointToEnd(entryBB);
  mlir::Operation *yieldOp = nullptr;
  if (!op.getCtorRegion().empty()) {
    mlir::Block &block = op.getCtorRegion().front();
    yieldOp = &block.getOperations().back();
  } else {
    assert(!dtorRegion.empty());
    mlir::Block &block = dtorRegion.front();
    yieldOp = &block.getOperations().back();
  }

  assert(isa<cir::YieldOp>(*yieldOp));
  cir::ReturnOp::create(builder, yieldOp->getLoc());
  return f;
}
/// If global `op` has a ctor or dtor region, move that code into a new
/// `__cxx_global_var_init` function, clear both regions, and record the
/// function in `dynamicInitializers`. Globals without such regions are left
/// unchanged.
void LoweringPreparePass::lowerGlobalOp(GlobalOp op) {
  mlir::Region &ctorRegion = op.getCtorRegion();
  mlir::Region &dtorRegion = op.getDtorRegion();
  const bool needsDynamicInit = !(ctorRegion.empty() && dtorRegion.empty());

  if (needsDynamicInit) {
    // Build a variable initialization function and move the initialization
    // code in the ctor region over.
    cir::FuncOp initFn = buildCXXGlobalVarDeclInitFunc(op);

    // The code now lives in initFn; drop the original regions.
    ctorRegion.getBlocks().clear();
    dtorRegion.getBlocks().clear();

    assert(!cir::MissingFeatures::astVarDeclInterface());
    dynamicInitializers.push_back(initFn);
  }

  assert(!cir::MissingFeatures::opGlobalAnnotations());
}
/// Convert (symbol name, priority) pairs into attributes of type
/// `AttributeTy`, built with `AttributeTy::get(context, name, priority)`.
/// Input order is preserved.
template <typename AttributeTy>
static llvm::SmallVector<mlir::Attribute>
prepareCtorDtorAttrList(mlir::MLIRContext *context,
                        llvm::ArrayRef<std::pair<std::string, uint32_t>> list) {
  llvm::SmallVector<mlir::Attribute> result;
  result.reserve(list.size());
  for (const std::pair<std::string, uint32_t> &entry : list)
    result.push_back(AttributeTy::get(context, entry.first, entry.second));
  return result;
}
void LoweringPreparePass::buildGlobalCtorDtorList() {
if (!globalCtorList.empty()) {
llvm::SmallVector<mlir::Attribute> globalCtors =
prepareCtorDtorAttrList<cir::GlobalCtorAttr>(&getContext(),
globalCtorList);
mlirModule->setAttr(cir::CIRDialect::getGlobalCtorsAttrName(),
mlir::ArrayAttr::get(&getContext(), globalCtors));
}
// We will eventual need to populate a global_dtor list, but that's not
// needed for globals with destructors. It will only be needed for functions
// that are marked as global destructors with an attribute.
assert(!cir::MissingFeatures::opGlobalDtorList());
}
void LoweringPreparePass::buildCXXGlobalInitFunc() {
if (dynamicInitializers.empty())
return;
// TODO: handle globals with a user-specified initialzation priority.
// TODO: handle default priority more nicely.
assert(!cir::MissingFeatures::opGlobalCtorPriority());
SmallString<256> fnName;
// Include the filename in the symbol name. Including "sub_" matches gcc
// and makes sure these symbols appear lexicographically behind the symbols
// with priority (TBD). Module implementation units behave the same
// way as a non-modular TU with imports.
// TODO: check CXX20ModuleInits
if (astCtx->getCurrentNamedModule() &&
!astCtx->getCurrentNamedModule()->isModuleImplementation()) {
llvm::raw_svector_ostream out(fnName);
std::unique_ptr<clang::MangleContext> mangleCtx(
astCtx->createMangleContext());
cast<clang::ItaniumMangleContext>(*mangleCtx)
.mangleModuleInitializer(astCtx->getCurrentNamedModule(), out);
} else {
fnName += "_GLOBAL__sub_I_";
fnName += getTransformedFileName(mlirModule);
}
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointToEnd(&mlirModule.getBodyRegion().back());
auto fnType = cir::FuncType::get({}, builder.getVoidTy());
cir::FuncOp f =
buildRuntimeFunction(builder, fnName, mlirModule.getLoc(), fnType,
cir::GlobalLinkageKind::ExternalLinkage);
builder.setInsertionPointToStart(f.addEntryBlock());
for (cir::FuncOp &f : dynamicInitializers)
builder.createCallOp(f.getLoc(), f, {});
// Add the global init function (not the individual ctor functions) to the
// global ctor list.
globalCtorList.emplace_back(fnName,
cir::GlobalCtorAttr::getDefaultPriority());
cir::ReturnOp::create(builder, f.getLoc());
}
/// Expand an array ctor/dtor operation `op` into an explicit cir.do loop that
/// calls the element constructor/destructor once per element, then replace
/// and erase `op`.
///
/// \param builder    positioned where the loop is emitted.
/// \param astCtx     used only to query the width of the signed size type,
///                   which sizes the index and stride constants.
/// \param op         the array ctor/dtor op. Its body must contain the cir.call
///                   to invoke; that call's first operand is replaced by the
///                   current element pointer.
/// \param eltTy      pointer-to-element type. It is the result type of
///                   decaying `arrayAddr` and the type of the loop cursor.
/// \param arrayAddr  address of the array.
/// \param arrayLen   number of elements.
/// \param isCtor     true to visit elements 0..N-1 (construction order);
///                   false to visit elements N-1..0 (destruction order).
///
/// The cursor `__array_idx` starts at element 0 for constructors and at one
/// past the last element for destructors.
///  - Constructors call on the current element, then advance the cursor.
///  - Destructors step the cursor back first, then call, so element 0 is the
///    last one destroyed.
/// Zero-length arrays need no calls, so `op` is simply erased.
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder,
                                       clang::ASTContext *astCtx,
                                       mlir::Operation *op, mlir::Type eltTy,
                                       mlir::Value arrayAddr, uint64_t arrayLen,
                                       bool isCtor) {
  // A cir.do body always runs at least once. An empty array must therefore
  // not reach the loop below.
  if (arrayLen == 0) {
    op->erase();
    return;
  }

  // Generate loop to call into ctor/dtor for every element.
  mlir::Location loc = op->getLoc();

  // TODO: instead of getting the size from the AST context, create alias for
  // PtrDiffTy and unify with CIRGen stuff.
  const unsigned sizeTypeSize =
      astCtx->getTypeSize(astCtx->getSignedSizeType());

  // `end` points one past the last element in both modes.
  mlir::Value endOffsetVal =
      builder.getUnsignedInt(loc, arrayLen, sizeTypeSize);

  auto begin = cir::CastOp::create(builder, loc, eltTy,
                                   cir::CastKind::array_to_ptrdecay, arrayAddr);
  mlir::Value end =
      cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal);
  mlir::Value start = isCtor ? begin : end;
  mlir::Value stop = isCtor ? end : begin;

  mlir::Value tmpAddr = builder.createAlloca(
      loc, /*addr type*/ builder.getPointerTo(eltTy),
      /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1));
  builder.createStore(loc, start, tmpAddr);

  cir::DoWhileOp loop = builder.createDoWhile(
      loc,
      /*condBuilder=*/
      [&](mlir::OpBuilder &b, mlir::Location loc) {
        // Evaluated after each body run: continue until the cursor
        // reaches `stop`.
        auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr);
        mlir::Type boolTy = cir::BoolType::get(b.getContext());
        auto cmp = builder.create<cir::CmpOp>(loc, boolTy, cir::CmpOpKind::ne,
                                              currentElement, stop);
        builder.createCondition(cmp);
      },
      /*bodyBuilder=*/
      [&](mlir::OpBuilder &b, mlir::Location loc) {
        auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr);

        cir::CallOp ctorCall;
        op->walk([&](cir::CallOp c) { ctorCall = c; });
        assert(ctorCall && "expected ctor call");

        if (isCtor) {
          // Construct the current element, then advance the cursor.
          mlir::Value stride = builder.getUnsignedInt(loc, 1, sizeTypeSize);
          ctorCall->moveBefore(stride.getDefiningOp());
          ctorCall->setOperand(0, currentElement);
          auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy,
                                                      currentElement, stride);
          // Store the element pointer to the temporary variable
          builder.createStore(loc, nextElement, tmpAddr);
        } else {
          // Elements are destroyed in reverse. The cursor points one past
          // the element to destroy, so step back first and then call the
          // destructor. Once element 0 is destroyed the cursor equals
          // `begin` and the loop ends.
          mlir::Value stride = builder.getSignedInt(loc, -1, sizeTypeSize);
          auto prevElement = cir::PtrStrideOp::create(builder, loc, eltTy,
                                                      currentElement, stride);
          ctorCall->moveAfter(prevElement.getOperation());
          ctorCall->setOperand(0, prevElement);
          // Store the element pointer to the temporary variable
          builder.createStore(loc, prevElement, tmpAddr);
        }
        builder.createYield(loc);
      });

  op->replaceAllUsesWith(loop);
  op->erase();
}
/// Lower a cir::ArrayDtor into a loop that destroys every element of the
/// array in reverse order (see lowerArrayDtorCtorIntoLoop).
void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) {
  cir::CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op.getOperation());

  // The body region's argument type is used as the element pointer type.
  mlir::Type eltPtrTy = op->getRegion(0).getArgument(0).getType();
  assert(!cir::MissingFeatures::vlas());
  auto arrayAddr = op.getAddr();
  auto numElements =
      mlir::cast<cir::ArrayType>(arrayAddr.getType().getPointee()).getSize();

  lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltPtrTy, arrayAddr,
                             numElements, /*isCtor=*/false);
}
/// Lower a cir::ArrayCtor into a loop that constructs every element of the
/// array in order (see lowerArrayDtorCtorIntoLoop).
void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) {
  cir::CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op.getOperation());

  // The body region's argument type is used as the element pointer type.
  mlir::Type eltPtrTy = op->getRegion(0).getArgument(0).getType();
  assert(!cir::MissingFeatures::vlas());
  auto arrayAddr = op.getAddr();
  auto numElements =
      mlir::cast<cir::ArrayType>(arrayAddr.getType().getPointee()).getSize();

  lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltPtrTy, arrayAddr,
                             numElements, /*isCtor=*/true);
}
/// Route `op` to the lowering for its operation type. Operations of any other
/// type are ignored.
void LoweringPreparePass::runOnOp(mlir::Operation *op) {
  if (auto arrayCtor = mlir::dyn_cast<cir::ArrayCtor>(op)) {
    lowerArrayCtor(arrayCtor);
    return;
  }
  if (auto arrayDtor = mlir::dyn_cast<cir::ArrayDtor>(op)) {
    lowerArrayDtor(arrayDtor);
    return;
  }
  if (auto castOp = mlir::dyn_cast<cir::CastOp>(op)) {
    lowerCastOp(castOp);
    return;
  }
  if (auto divOp = mlir::dyn_cast<cir::ComplexDivOp>(op)) {
    lowerComplexDivOp(divOp);
    return;
  }
  if (auto mulOp = mlir::dyn_cast<cir::ComplexMulOp>(op)) {
    lowerComplexMulOp(mulOp);
    return;
  }
  if (auto globalOp = mlir::dyn_cast<cir::GlobalOp>(op)) {
    lowerGlobalOp(globalOp);
    return;
  }
  if (auto unaryOp = mlir::dyn_cast<cir::UnaryOp>(op))
    lowerUnaryOp(unaryOp);
}
void LoweringPreparePass::runOnOperation() {
mlir::Operation *op = getOperation();
if (isa<::mlir::ModuleOp>(op))
mlirModule = cast<::mlir::ModuleOp>(op);
llvm::SmallVector<mlir::Operation *> opsToTransform;
op->walk([&](mlir::Operation *op) {
if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
cir::ComplexMulOp, cir::ComplexDivOp, cir::GlobalOp,
cir::UnaryOp>(op))
opsToTransform.push_back(op);
});
for (mlir::Operation *o : opsToTransform)
runOnOp(o);
buildCXXGlobalInitFunc();
buildGlobalCtorDtorList();
}
/// Create the LoweringPrepare pass without an AST context.
std::unique_ptr<Pass> mlir::createLoweringPreparePass() {
  return std::make_unique<LoweringPreparePass>();
}

/// Create the LoweringPrepare pass and give it `astCtx` for target and
/// language queries.
std::unique_ptr<Pass>
mlir::createLoweringPreparePass(clang::ASTContext *astCtx) {
  std::unique_ptr<LoweringPreparePass> loweringPrepare =
      std::make_unique<LoweringPreparePass>();
  loweringPrepare->setASTContext(astCtx);
  return std::move(loweringPrepare);
}