//===- LoweringPrepare.cpp - preparation work for LLVM lowering -----------===//
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "PassDetail.h" |
| #include "clang/AST/ASTContext.h" |
| #include "clang/Basic/Module.h" |
| #include "clang/Basic/TargetInfo.h" |
| #include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h" |
| #include "clang/CIR/Dialect/IR/CIRDialect.h" |
| #include "clang/CIR/Dialect/IR/CIROpsEnums.h" |
| #include "clang/CIR/Dialect/Passes.h" |
| #include "clang/CIR/MissingFeatures.h" |
| #include "llvm/Support/Path.h" |
| |
| #include <memory> |
| |
| using namespace mlir; |
| using namespace cir; |
| |
| static SmallString<128> getTransformedFileName(mlir::ModuleOp mlirModule) { |
| SmallString<128> fileName; |
| |
| if (mlirModule.getSymName()) |
| fileName = llvm::sys::path::filename(mlirModule.getSymName()->str()); |
| |
| if (fileName.empty()) |
| fileName = "<null>"; |
| |
| for (size_t i = 0; i < fileName.size(); ++i) { |
| // Replace everything that's not [a-zA-Z0-9._] with a _. This set happens |
| // to be the set of C preprocessing numbers. |
| if (!clang::isPreprocessingNumberBody(fileName[i])) |
| fileName[i] = '_'; |
| } |
| |
| return fileName; |
| } |
| |
| /// Return the FuncOp called by `callOp`. |
| static cir::FuncOp getCalledFunction(cir::CallOp callOp) { |
| mlir::SymbolRefAttr sym = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>( |
| callOp.getCallableForCallee()); |
| if (!sym) |
| return nullptr; |
| return dyn_cast_or_null<cir::FuncOp>( |
| mlir::SymbolTable::lookupNearestSymbolFrom(callOp, sym)); |
| } |
| |
| namespace { |
| struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> { |
| LoweringPreparePass() = default; |
| void runOnOperation() override; |
| |
| void runOnOp(mlir::Operation *op); |
| void lowerCastOp(cir::CastOp op); |
| void lowerComplexDivOp(cir::ComplexDivOp op); |
| void lowerComplexMulOp(cir::ComplexMulOp op); |
| void lowerUnaryOp(cir::UnaryOp op); |
| void lowerGlobalOp(cir::GlobalOp op); |
| void lowerArrayDtor(cir::ArrayDtor op); |
| void lowerArrayCtor(cir::ArrayCtor op); |
| |
| /// Build the function that initializes the specified global |
| cir::FuncOp buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op); |
| |
| /// Build a module init function that calls all the dynamic initializers. |
| void buildCXXGlobalInitFunc(); |
| |
| /// Materialize global ctor/dtor list |
| void buildGlobalCtorDtorList(); |
| |
| cir::FuncOp buildRuntimeFunction( |
| mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, |
| cir::FuncType type, |
| cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage); |
| |
| cir::GlobalOp buildRuntimeVariable( |
| mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, |
| mlir::Type type, |
| cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage, |
| cir::VisibilityKind visibility = cir::VisibilityKind::Default); |
| |
| /// |
| /// AST related |
| /// ----------- |
| |
| clang::ASTContext *astCtx; |
| |
| /// Tracks current module. |
| mlir::ModuleOp mlirModule; |
| |
| /// Tracks existing dynamic initializers. |
| llvm::StringMap<uint32_t> dynamicInitializerNames; |
| llvm::SmallVector<cir::FuncOp> dynamicInitializers; |
| |
| /// List of ctors and their priorities to be called before main() |
| llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalCtorList; |
| |
| void setASTContext(clang::ASTContext *c) { astCtx = c; } |
| }; |
| |
| } // namespace |
| |
| cir::GlobalOp LoweringPreparePass::buildRuntimeVariable( |
| mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, |
| mlir::Type type, cir::GlobalLinkageKind linkage, |
| cir::VisibilityKind visibility) { |
| cir::GlobalOp g = dyn_cast_or_null<cir::GlobalOp>( |
| mlir::SymbolTable::lookupNearestSymbolFrom( |
| mlirModule, mlir::StringAttr::get(mlirModule->getContext(), name))); |
| if (!g) { |
| g = cir::GlobalOp::create(builder, loc, name, type); |
| g.setLinkageAttr( |
| cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage)); |
| mlir::SymbolTable::setSymbolVisibility( |
| g, mlir::SymbolTable::Visibility::Private); |
| g.setGlobalVisibilityAttr( |
| cir::VisibilityAttr::get(builder.getContext(), visibility)); |
| } |
| return g; |
| } |
| |
| cir::FuncOp LoweringPreparePass::buildRuntimeFunction( |
| mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, |
| cir::FuncType type, cir::GlobalLinkageKind linkage) { |
| cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom( |
| mlirModule, StringAttr::get(mlirModule->getContext(), name))); |
| if (!f) { |
| f = builder.create<cir::FuncOp>(loc, name, type); |
| f.setLinkageAttr( |
| cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage)); |
| mlir::SymbolTable::setSymbolVisibility( |
| f, mlir::SymbolTable::Visibility::Private); |
| |
| assert(!cir::MissingFeatures::opFuncExtraAttrs()); |
| } |
| return f; |
| } |
| |
| static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, |
| cir::CastOp op) { |
| cir::CIRBaseBuilderTy builder(ctx); |
| builder.setInsertionPoint(op); |
| |
| mlir::Value src = op.getSrc(); |
| mlir::Value imag = builder.getNullValue(src.getType(), op.getLoc()); |
| return builder.createComplexCreate(op.getLoc(), src, imag); |
| } |
| |
| static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx, |
| cir::CastOp op, |
| cir::CastKind elemToBoolKind) { |
| cir::CIRBaseBuilderTy builder(ctx); |
| builder.setInsertionPoint(op); |
| |
| mlir::Value src = op.getSrc(); |
| if (!mlir::isa<cir::BoolType>(op.getType())) |
| return builder.createComplexReal(op.getLoc(), src); |
| |
| // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b |
| mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src); |
| mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src); |
| |
| cir::BoolType boolTy = builder.getBoolTy(); |
| mlir::Value srcRealToBool = |
| builder.createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy); |
| mlir::Value srcImagToBool = |
| builder.createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy); |
| return builder.createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool); |
| } |
| |
| static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx, |
| cir::CastOp op, |
| cir::CastKind scalarCastKind) { |
| CIRBaseBuilderTy builder(ctx); |
| builder.setInsertionPoint(op); |
| |
| mlir::Value src = op.getSrc(); |
| auto dstComplexElemTy = |
| mlir::cast<cir::ComplexType>(op.getType()).getElementType(); |
| |
| mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src); |
| mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src); |
| |
| mlir::Value dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal, |
| dstComplexElemTy); |
| mlir::Value dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag, |
| dstComplexElemTy); |
| return builder.createComplexCreate(op.getLoc(), dstReal, dstImag); |
| } |
| |
| void LoweringPreparePass::lowerCastOp(cir::CastOp op) { |
| mlir::MLIRContext &ctx = getContext(); |
| mlir::Value loweredValue = [&]() -> mlir::Value { |
| switch (op.getKind()) { |
| case cir::CastKind::float_to_complex: |
| case cir::CastKind::int_to_complex: |
| return lowerScalarToComplexCast(ctx, op); |
| case cir::CastKind::float_complex_to_real: |
| case cir::CastKind::int_complex_to_real: |
| return lowerComplexToScalarCast(ctx, op, op.getKind()); |
| case cir::CastKind::float_complex_to_bool: |
| return lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool); |
| case cir::CastKind::int_complex_to_bool: |
| return lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool); |
| case cir::CastKind::float_complex: |
| return lowerComplexToComplexCast(ctx, op, cir::CastKind::floating); |
| case cir::CastKind::float_complex_to_int_complex: |
| return lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int); |
| case cir::CastKind::int_complex: |
| return lowerComplexToComplexCast(ctx, op, cir::CastKind::integral); |
| case cir::CastKind::int_complex_to_float_complex: |
| return lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float); |
| default: |
| return nullptr; |
| } |
| }(); |
| |
| if (loweredValue) { |
| op.replaceAllUsesWith(loweredValue); |
| op.erase(); |
| } |
| } |
| |
| static mlir::Value buildComplexBinOpLibCall( |
| LoweringPreparePass &pass, CIRBaseBuilderTy &builder, |
| llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics), |
| mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, |
| mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) { |
| cir::FPTypeInterface elementTy = |
| mlir::cast<cir::FPTypeInterface>(ty.getElementType()); |
| |
| llvm::StringRef libFuncName = libFuncNameGetter( |
| llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics())); |
| llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy); |
| |
| cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty); |
| |
| // Insert a declaration for the runtime function to be used in Complex |
| // multiplication and division when needed |
| cir::FuncOp libFunc; |
| { |
| mlir::OpBuilder::InsertionGuard ipGuard{builder}; |
| builder.setInsertionPointToStart(pass.mlirModule.getBody()); |
| libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy); |
| } |
| |
| cir::CallOp call = |
| builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag}); |
| return call.getResult(); |
| } |
| |
| static llvm::StringRef |
| getComplexDivLibCallName(llvm::APFloat::Semantics semantics) { |
| switch (semantics) { |
| case llvm::APFloat::S_IEEEhalf: |
| return "__divhc3"; |
| case llvm::APFloat::S_IEEEsingle: |
| return "__divsc3"; |
| case llvm::APFloat::S_IEEEdouble: |
| return "__divdc3"; |
| case llvm::APFloat::S_PPCDoubleDouble: |
| return "__divtc3"; |
| case llvm::APFloat::S_x87DoubleExtended: |
| return "__divxc3"; |
| case llvm::APFloat::S_IEEEquad: |
| return "__divtc3"; |
| default: |
| llvm_unreachable("unsupported floating point type"); |
| } |
| } |
| |
| static mlir::Value |
| buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, |
| mlir::Value lhsReal, mlir::Value lhsImag, |
| mlir::Value rhsReal, mlir::Value rhsImag) { |
| // (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i |
| mlir::Value &a = lhsReal; |
| mlir::Value &b = lhsImag; |
| mlir::Value &c = rhsReal; |
| mlir::Value &d = rhsImag; |
| |
| mlir::Value ac = builder.createBinop(loc, a, cir::BinOpKind::Mul, c); // a*c |
| mlir::Value bd = builder.createBinop(loc, b, cir::BinOpKind::Mul, d); // b*d |
| mlir::Value cc = builder.createBinop(loc, c, cir::BinOpKind::Mul, c); // c*c |
| mlir::Value dd = builder.createBinop(loc, d, cir::BinOpKind::Mul, d); // d*d |
| mlir::Value acbd = |
| builder.createBinop(loc, ac, cir::BinOpKind::Add, bd); // ac+bd |
| mlir::Value ccdd = |
| builder.createBinop(loc, cc, cir::BinOpKind::Add, dd); // cc+dd |
| mlir::Value resultReal = |
| builder.createBinop(loc, acbd, cir::BinOpKind::Div, ccdd); |
| |
| mlir::Value bc = builder.createBinop(loc, b, cir::BinOpKind::Mul, c); // b*c |
| mlir::Value ad = builder.createBinop(loc, a, cir::BinOpKind::Mul, d); // a*d |
| mlir::Value bcad = |
| builder.createBinop(loc, bc, cir::BinOpKind::Sub, ad); // bc-ad |
| mlir::Value resultImag = |
| builder.createBinop(loc, bcad, cir::BinOpKind::Div, ccdd); |
| return builder.createComplexCreate(loc, resultReal, resultImag); |
| } |
| |
| static mlir::Value |
| buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, |
| mlir::Value lhsReal, mlir::Value lhsImag, |
| mlir::Value rhsReal, mlir::Value rhsImag) { |
| // Implements Smith's algorithm for complex division. |
| // SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962). |
| |
| // Let: |
| // - lhs := a+bi |
| // - rhs := c+di |
| // - result := lhs / rhs = e+fi |
| // |
| // The algorithm pseudocode looks like follows: |
| // if fabs(c) >= fabs(d): |
| // r := d / c |
| // tmp := c + r*d |
| // e = (a + b*r) / tmp |
| // f = (b - a*r) / tmp |
| // else: |
| // r := c / d |
| // tmp := d + r*c |
| // e = (a*r + b) / tmp |
| // f = (b*r - a) / tmp |
| |
| mlir::Value &a = lhsReal; |
| mlir::Value &b = lhsImag; |
| mlir::Value &c = rhsReal; |
| mlir::Value &d = rhsImag; |
| |
| auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) { |
| mlir::Value r = builder.createBinop(loc, d, cir::BinOpKind::Div, |
| c); // r := d / c |
| mlir::Value rd = builder.createBinop(loc, r, cir::BinOpKind::Mul, d); // r*d |
| mlir::Value tmp = builder.createBinop(loc, c, cir::BinOpKind::Add, |
| rd); // tmp := c + r*d |
| |
| mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r |
| mlir::Value abr = |
| builder.createBinop(loc, a, cir::BinOpKind::Add, br); // a + b*r |
| mlir::Value e = builder.createBinop(loc, abr, cir::BinOpKind::Div, tmp); |
| |
| mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r |
| mlir::Value bar = |
| builder.createBinop(loc, b, cir::BinOpKind::Sub, ar); // b - a*r |
| mlir::Value f = builder.createBinop(loc, bar, cir::BinOpKind::Div, tmp); |
| |
| mlir::Value result = builder.createComplexCreate(loc, e, f); |
| builder.createYield(loc, result); |
| }; |
| |
| auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) { |
| mlir::Value r = builder.createBinop(loc, c, cir::BinOpKind::Div, |
| d); // r := c / d |
| mlir::Value rc = builder.createBinop(loc, r, cir::BinOpKind::Mul, c); // r*c |
| mlir::Value tmp = builder.createBinop(loc, d, cir::BinOpKind::Add, |
| rc); // tmp := d + r*c |
| |
| mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r |
| mlir::Value arb = |
| builder.createBinop(loc, ar, cir::BinOpKind::Add, b); // a*r + b |
| mlir::Value e = builder.createBinop(loc, arb, cir::BinOpKind::Div, tmp); |
| |
| mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r |
| mlir::Value bra = |
| builder.createBinop(loc, br, cir::BinOpKind::Sub, a); // b*r - a |
| mlir::Value f = builder.createBinop(loc, bra, cir::BinOpKind::Div, tmp); |
| |
| mlir::Value result = builder.createComplexCreate(loc, e, f); |
| builder.createYield(loc, result); |
| }; |
| |
| auto cFabs = builder.create<cir::FAbsOp>(loc, c); |
| auto dFabs = builder.create<cir::FAbsOp>(loc, d); |
| cir::CmpOp cmpResult = |
| builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs); |
| auto ternary = builder.create<cir::TernaryOp>( |
| loc, cmpResult, trueBranchBuilder, falseBranchBuilder); |
| |
| return ternary.getResult(); |
| } |
| |
| static mlir::Type higherPrecisionElementTypeForComplexArithmetic( |
| mlir::MLIRContext &context, clang::ASTContext &cc, |
| CIRBaseBuilderTy &builder, mlir::Type elementType) { |
| |
| auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type { |
| if (mlir::isa<cir::FP16Type>(type)) |
| return cir::SingleType::get(&context); |
| |
| if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type)) |
| return cir::DoubleType::get(&context); |
| |
| if (mlir::isa<cir::DoubleType>(type)) |
| return cir::LongDoubleType::get(&context, type); |
| |
| return type; |
| }; |
| |
| auto getFloatTypeSemantics = |
| [&cc](mlir::Type type) -> const llvm::fltSemantics & { |
| const clang::TargetInfo &info = cc.getTargetInfo(); |
| if (mlir::isa<cir::FP16Type>(type)) |
| return info.getHalfFormat(); |
| |
| if (mlir::isa<cir::BF16Type>(type)) |
| return info.getBFloat16Format(); |
| |
| if (mlir::isa<cir::SingleType>(type)) |
| return info.getFloatFormat(); |
| |
| if (mlir::isa<cir::DoubleType>(type)) |
| return info.getDoubleFormat(); |
| |
| if (mlir::isa<cir::LongDoubleType>(type)) { |
| if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice) |
| llvm_unreachable("NYI Float type semantics with OpenMP"); |
| return info.getLongDoubleFormat(); |
| } |
| |
| if (mlir::isa<cir::FP128Type>(type)) { |
| if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice) |
| llvm_unreachable("NYI Float type semantics with OpenMP"); |
| return info.getFloat128Format(); |
| } |
| |
| assert(false && "Unsupported float type semantics"); |
| }; |
| |
| const mlir::Type higherElementType = getHigherPrecisionFPType(elementType); |
| const llvm::fltSemantics &elementTypeSemantics = |
| getFloatTypeSemantics(elementType); |
| const llvm::fltSemantics &higherElementTypeSemantics = |
| getFloatTypeSemantics(higherElementType); |
| |
| // Check that the promoted type can handle the intermediate values without |
| // overflowing. This can be interpreted as: |
| // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <= |
| // LargerType.LargestFiniteVal. |
| // In terms of exponent it gives this formula: |
| // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal |
| // doubles the exponent of SmallerType.LargestFiniteVal) |
| if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <= |
| llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) { |
| return higherElementType; |
| } |
| |
| // The intermediate values can't be represented in the promoted type |
| // without overflowing. |
| return {}; |
| } |
| |
| static mlir::Value |
| lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, |
| mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal, |
| mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag, |
| mlir::MLIRContext &mlirCx, clang::ASTContext &cc) { |
| cir::ComplexType complexTy = op.getType(); |
| if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType())) { |
| cir::ComplexRangeKind range = op.getRange(); |
| if (range == cir::ComplexRangeKind::Improved) |
| return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag, |
| rhsReal, rhsImag); |
| |
| if (range == cir::ComplexRangeKind::Full) |
| return buildComplexBinOpLibCall(pass, builder, &getComplexDivLibCallName, |
| loc, complexTy, lhsReal, lhsImag, rhsReal, |
| rhsImag); |
| |
| if (range == cir::ComplexRangeKind::Promoted) { |
| mlir::Type originalElementType = complexTy.getElementType(); |
| mlir::Type higherPrecisionElementType = |
| higherPrecisionElementTypeForComplexArithmetic(mlirCx, cc, builder, |
| originalElementType); |
| |
| if (!higherPrecisionElementType) |
| return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag, |
| rhsReal, rhsImag); |
| |
| cir::CastKind floatingCastKind = cir::CastKind::floating; |
| lhsReal = builder.createCast(floatingCastKind, lhsReal, |
| higherPrecisionElementType); |
| lhsImag = builder.createCast(floatingCastKind, lhsImag, |
| higherPrecisionElementType); |
| rhsReal = builder.createCast(floatingCastKind, rhsReal, |
| higherPrecisionElementType); |
| rhsImag = builder.createCast(floatingCastKind, rhsImag, |
| higherPrecisionElementType); |
| |
| mlir::Value algebraicResult = buildAlgebraicComplexDiv( |
| builder, loc, lhsReal, lhsImag, rhsReal, rhsImag); |
| |
| mlir::Value resultReal = builder.createComplexReal(loc, algebraicResult); |
| mlir::Value resultImag = builder.createComplexImag(loc, algebraicResult); |
| |
| mlir::Value finalReal = |
| builder.createCast(floatingCastKind, resultReal, originalElementType); |
| mlir::Value finalImag = |
| builder.createCast(floatingCastKind, resultImag, originalElementType); |
| return builder.createComplexCreate(loc, finalReal, finalImag); |
| } |
| } |
| |
| return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal, |
| rhsImag); |
| } |
| |
| void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) { |
| cir::CIRBaseBuilderTy builder(getContext()); |
| builder.setInsertionPointAfter(op); |
| mlir::Location loc = op.getLoc(); |
| mlir::TypedValue<cir::ComplexType> lhs = op.getLhs(); |
| mlir::TypedValue<cir::ComplexType> rhs = op.getRhs(); |
| mlir::Value lhsReal = builder.createComplexReal(loc, lhs); |
| mlir::Value lhsImag = builder.createComplexImag(loc, lhs); |
| mlir::Value rhsReal = builder.createComplexReal(loc, rhs); |
| mlir::Value rhsImag = builder.createComplexImag(loc, rhs); |
| |
| mlir::Value loweredResult = |
| lowerComplexDiv(*this, builder, loc, op, lhsReal, lhsImag, rhsReal, |
| rhsImag, getContext(), *astCtx); |
| op.replaceAllUsesWith(loweredResult); |
| op.erase(); |
| } |
| |
| static llvm::StringRef |
| getComplexMulLibCallName(llvm::APFloat::Semantics semantics) { |
| switch (semantics) { |
| case llvm::APFloat::S_IEEEhalf: |
| return "__mulhc3"; |
| case llvm::APFloat::S_IEEEsingle: |
| return "__mulsc3"; |
| case llvm::APFloat::S_IEEEdouble: |
| return "__muldc3"; |
| case llvm::APFloat::S_PPCDoubleDouble: |
| return "__multc3"; |
| case llvm::APFloat::S_x87DoubleExtended: |
| return "__mulxc3"; |
| case llvm::APFloat::S_IEEEquad: |
| return "__multc3"; |
| default: |
| llvm_unreachable("unsupported floating point type"); |
| } |
| } |
| |
| static mlir::Value lowerComplexMul(LoweringPreparePass &pass, |
| CIRBaseBuilderTy &builder, |
| mlir::Location loc, cir::ComplexMulOp op, |
| mlir::Value lhsReal, mlir::Value lhsImag, |
| mlir::Value rhsReal, mlir::Value rhsImag) { |
| // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i |
| mlir::Value resultRealLhs = |
| builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal); |
| mlir::Value resultRealRhs = |
| builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag); |
| mlir::Value resultImagLhs = |
| builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag); |
| mlir::Value resultImagRhs = |
| builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal); |
| mlir::Value resultReal = builder.createBinop( |
| loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs); |
| mlir::Value resultImag = builder.createBinop( |
| loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs); |
| mlir::Value algebraicResult = |
| builder.createComplexCreate(loc, resultReal, resultImag); |
| |
| cir::ComplexType complexTy = op.getType(); |
| cir::ComplexRangeKind rangeKind = op.getRange(); |
| if (mlir::isa<cir::IntType>(complexTy.getElementType()) || |
| rangeKind == cir::ComplexRangeKind::Basic || |
| rangeKind == cir::ComplexRangeKind::Improved || |
| rangeKind == cir::ComplexRangeKind::Promoted) |
| return algebraicResult; |
| |
| assert(!cir::MissingFeatures::fastMathFlags()); |
| |
| // Check whether the real part and the imaginary part of the result are both |
| // NaN. If so, emit a library call to compute the multiplication instead. |
| // We check a value against NaN by comparing the value against itself. |
| mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal); |
| mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag); |
| mlir::Value resultRealAndImagAreNaN = |
| builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN); |
| |
| return builder |
| .create<cir::TernaryOp>( |
| loc, resultRealAndImagAreNaN, |
| [&](mlir::OpBuilder &, mlir::Location) { |
| mlir::Value libCallResult = buildComplexBinOpLibCall( |
| pass, builder, &getComplexMulLibCallName, loc, complexTy, |
| lhsReal, lhsImag, rhsReal, rhsImag); |
| builder.createYield(loc, libCallResult); |
| }, |
| [&](mlir::OpBuilder &, mlir::Location) { |
| builder.createYield(loc, algebraicResult); |
| }) |
| .getResult(); |
| } |
| |
| void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) { |
| cir::CIRBaseBuilderTy builder(getContext()); |
| builder.setInsertionPointAfter(op); |
| mlir::Location loc = op.getLoc(); |
| mlir::TypedValue<cir::ComplexType> lhs = op.getLhs(); |
| mlir::TypedValue<cir::ComplexType> rhs = op.getRhs(); |
| mlir::Value lhsReal = builder.createComplexReal(loc, lhs); |
| mlir::Value lhsImag = builder.createComplexImag(loc, lhs); |
| mlir::Value rhsReal = builder.createComplexReal(loc, rhs); |
| mlir::Value rhsImag = builder.createComplexImag(loc, rhs); |
| mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal, |
| lhsImag, rhsReal, rhsImag); |
| op.replaceAllUsesWith(loweredResult); |
| op.erase(); |
| } |
| |
| void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) { |
| mlir::Type ty = op.getType(); |
| if (!mlir::isa<cir::ComplexType>(ty)) |
| return; |
| |
| mlir::Location loc = op.getLoc(); |
| cir::UnaryOpKind opKind = op.getKind(); |
| |
| CIRBaseBuilderTy builder(getContext()); |
| builder.setInsertionPointAfter(op); |
| |
| mlir::Value operand = op.getInput(); |
| mlir::Value operandReal = builder.createComplexReal(loc, operand); |
| mlir::Value operandImag = builder.createComplexImag(loc, operand); |
| |
| mlir::Value resultReal; |
| mlir::Value resultImag; |
| |
| switch (opKind) { |
| case cir::UnaryOpKind::Inc: |
| case cir::UnaryOpKind::Dec: |
| resultReal = builder.createUnaryOp(loc, opKind, operandReal); |
| resultImag = operandImag; |
| break; |
| |
| case cir::UnaryOpKind::Plus: |
| case cir::UnaryOpKind::Minus: |
| resultReal = builder.createUnaryOp(loc, opKind, operandReal); |
| resultImag = builder.createUnaryOp(loc, opKind, operandImag); |
| break; |
| |
| case cir::UnaryOpKind::Not: |
| resultReal = operandReal; |
| resultImag = |
| builder.createUnaryOp(loc, cir::UnaryOpKind::Minus, operandImag); |
| break; |
| } |
| |
| mlir::Value result = builder.createComplexCreate(loc, resultReal, resultImag); |
| op.replaceAllUsesWith(result); |
| op.erase(); |
| } |
| |
| cir::FuncOp |
| LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) { |
| // TODO(cir): Store this in the GlobalOp. |
| // This should come from the MangleContext, but for now I'm hardcoding it. |
| SmallString<256> fnName("__cxx_global_var_init"); |
| // Get a unique name |
| uint32_t cnt = dynamicInitializerNames[fnName]++; |
| if (cnt) |
| fnName += "." + llvm::Twine(cnt).str(); |
| |
| // Create a variable initialization function. |
| CIRBaseBuilderTy builder(getContext()); |
| builder.setInsertionPointAfter(op); |
| cir::VoidType voidTy = builder.getVoidTy(); |
| auto fnType = cir::FuncType::get({}, voidTy); |
| FuncOp f = buildRuntimeFunction(builder, fnName, op.getLoc(), fnType, |
| cir::GlobalLinkageKind::InternalLinkage); |
| |
| // Move over the initialzation code of the ctor region. |
| mlir::Block *entryBB = f.addEntryBlock(); |
| if (!op.getCtorRegion().empty()) { |
| mlir::Block &block = op.getCtorRegion().front(); |
| entryBB->getOperations().splice(entryBB->begin(), block.getOperations(), |
| block.begin(), std::prev(block.end())); |
| } |
| |
| // Register the destructor call with __cxa_atexit |
| mlir::Region &dtorRegion = op.getDtorRegion(); |
| if (!dtorRegion.empty()) { |
| assert(!cir::MissingFeatures::astVarDeclInterface()); |
| assert(!cir::MissingFeatures::opGlobalThreadLocal()); |
| // Create a variable that binds the atexit to this shared object. |
| builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front()); |
| cir::GlobalOp handle = buildRuntimeVariable( |
| builder, "__dso_handle", op.getLoc(), builder.getI8Type(), |
| cir::GlobalLinkageKind::ExternalLinkage, cir::VisibilityKind::Hidden); |
| |
| // Look for the destructor call in dtorBlock |
| mlir::Block &dtorBlock = dtorRegion.front(); |
| cir::CallOp dtorCall; |
| for (auto op : reverse(dtorBlock.getOps<cir::CallOp>())) { |
| dtorCall = op; |
| break; |
| } |
| assert(dtorCall && "Expected a dtor call"); |
| cir::FuncOp dtorFunc = getCalledFunction(dtorCall); |
| assert(dtorFunc && "Expected a dtor call"); |
| |
| // Create a runtime helper function: |
| // extern "C" int __cxa_atexit(void (*f)(void *), void *p, void *d); |
| auto voidPtrTy = cir::PointerType::get(voidTy); |
| auto voidFnTy = cir::FuncType::get({voidPtrTy}, voidTy); |
| auto voidFnPtrTy = cir::PointerType::get(voidFnTy); |
| auto handlePtrTy = cir::PointerType::get(handle.getSymType()); |
| auto fnAtExitType = |
| cir::FuncType::get({voidFnPtrTy, voidPtrTy, handlePtrTy}, voidTy); |
| const char *nameAtExit = "__cxa_atexit"; |
| cir::FuncOp fnAtExit = |
| buildRuntimeFunction(builder, nameAtExit, op.getLoc(), fnAtExitType); |
| |
| // Replace the dtor call with a call to __cxa_atexit(&dtor, &var, |
| // &__dso_handle) |
| builder.setInsertionPointAfter(dtorCall); |
| mlir::Value args[3]; |
| auto dtorPtrTy = cir::PointerType::get(dtorFunc.getFunctionType()); |
| // dtorPtrTy |
| args[0] = cir::GetGlobalOp::create(builder, dtorCall.getLoc(), dtorPtrTy, |
| dtorFunc.getSymName()); |
| args[0] = cir::CastOp::create(builder, dtorCall.getLoc(), voidFnPtrTy, |
| cir::CastKind::bitcast, args[0]); |
| args[1] = |
| cir::CastOp::create(builder, dtorCall.getLoc(), voidPtrTy, |
| cir::CastKind::bitcast, dtorCall.getArgOperand(0)); |
| args[2] = cir::GetGlobalOp::create(builder, handle.getLoc(), handlePtrTy, |
| handle.getSymName()); |
| builder.createCallOp(dtorCall.getLoc(), fnAtExit, args); |
| dtorCall->erase(); |
| entryBB->getOperations().splice(entryBB->end(), dtorBlock.getOperations(), |
| dtorBlock.begin(), |
| std::prev(dtorBlock.end())); |
| } |
| |
| // Replace cir.yield with cir.return |
| builder.setInsertionPointToEnd(entryBB); |
| mlir::Operation *yieldOp = nullptr; |
| if (!op.getCtorRegion().empty()) { |
| mlir::Block &block = op.getCtorRegion().front(); |
| yieldOp = &block.getOperations().back(); |
| } else { |
| assert(!dtorRegion.empty()); |
| mlir::Block &block = dtorRegion.front(); |
| yieldOp = &block.getOperations().back(); |
| } |
| |
| assert(isa<cir::YieldOp>(*yieldOp)); |
| cir::ReturnOp::create(builder, yieldOp->getLoc()); |
| return f; |
| } |
| |
| void LoweringPreparePass::lowerGlobalOp(GlobalOp op) { |
| mlir::Region &ctorRegion = op.getCtorRegion(); |
| mlir::Region &dtorRegion = op.getDtorRegion(); |
| |
| if (!ctorRegion.empty() || !dtorRegion.empty()) { |
| // Build a variable initialization function and move the initialzation code |
| // in the ctor region over. |
| cir::FuncOp f = buildCXXGlobalVarDeclInitFunc(op); |
| |
| // Clear the ctor and dtor region |
| ctorRegion.getBlocks().clear(); |
| dtorRegion.getBlocks().clear(); |
| |
| assert(!cir::MissingFeatures::astVarDeclInterface()); |
| dynamicInitializers.push_back(f); |
| } |
| |
| assert(!cir::MissingFeatures::opGlobalAnnotations()); |
| } |
| |
| template <typename AttributeTy> |
| static llvm::SmallVector<mlir::Attribute> |
| prepareCtorDtorAttrList(mlir::MLIRContext *context, |
| llvm::ArrayRef<std::pair<std::string, uint32_t>> list) { |
| llvm::SmallVector<mlir::Attribute> attrs; |
| for (const auto &[name, priority] : list) |
| attrs.push_back(AttributeTy::get(context, name, priority)); |
| return attrs; |
| } |
| |
| void LoweringPreparePass::buildGlobalCtorDtorList() { |
| if (!globalCtorList.empty()) { |
| llvm::SmallVector<mlir::Attribute> globalCtors = |
| prepareCtorDtorAttrList<cir::GlobalCtorAttr>(&getContext(), |
| globalCtorList); |
| |
| mlirModule->setAttr(cir::CIRDialect::getGlobalCtorsAttrName(), |
| mlir::ArrayAttr::get(&getContext(), globalCtors)); |
| } |
| |
| // We will eventual need to populate a global_dtor list, but that's not |
| // needed for globals with destructors. It will only be needed for functions |
| // that are marked as global destructors with an attribute. |
| assert(!cir::MissingFeatures::opGlobalDtorList()); |
| } |
| |
| void LoweringPreparePass::buildCXXGlobalInitFunc() { |
| if (dynamicInitializers.empty()) |
| return; |
| |
| // TODO: handle globals with a user-specified initialzation priority. |
| // TODO: handle default priority more nicely. |
| assert(!cir::MissingFeatures::opGlobalCtorPriority()); |
| |
| SmallString<256> fnName; |
| // Include the filename in the symbol name. Including "sub_" matches gcc |
| // and makes sure these symbols appear lexicographically behind the symbols |
| // with priority (TBD). Module implementation units behave the same |
| // way as a non-modular TU with imports. |
| // TODO: check CXX20ModuleInits |
| if (astCtx->getCurrentNamedModule() && |
| !astCtx->getCurrentNamedModule()->isModuleImplementation()) { |
| llvm::raw_svector_ostream out(fnName); |
| std::unique_ptr<clang::MangleContext> mangleCtx( |
| astCtx->createMangleContext()); |
| cast<clang::ItaniumMangleContext>(*mangleCtx) |
| .mangleModuleInitializer(astCtx->getCurrentNamedModule(), out); |
| } else { |
| fnName += "_GLOBAL__sub_I_"; |
| fnName += getTransformedFileName(mlirModule); |
| } |
| |
| CIRBaseBuilderTy builder(getContext()); |
| builder.setInsertionPointToEnd(&mlirModule.getBodyRegion().back()); |
| auto fnType = cir::FuncType::get({}, builder.getVoidTy()); |
| cir::FuncOp f = |
| buildRuntimeFunction(builder, fnName, mlirModule.getLoc(), fnType, |
| cir::GlobalLinkageKind::ExternalLinkage); |
| builder.setInsertionPointToStart(f.addEntryBlock()); |
| for (cir::FuncOp &f : dynamicInitializers) |
| builder.createCallOp(f.getLoc(), f, {}); |
| // Add the global init function (not the individual ctor functions) to the |
| // global ctor list. |
| globalCtorList.emplace_back(fnName, |
| cir::GlobalCtorAttr::getDefaultPriority()); |
| |
| cir::ReturnOp::create(builder, f.getLoc()); |
| } |
| |
| static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, |
| clang::ASTContext *astCtx, |
| mlir::Operation *op, mlir::Type eltTy, |
| mlir::Value arrayAddr, uint64_t arrayLen, |
| bool isCtor) { |
| // Generate loop to call into ctor/dtor for every element. |
| mlir::Location loc = op->getLoc(); |
| |
| // TODO: instead of getting the size from the AST context, create alias for |
| // PtrDiffTy and unify with CIRGen stuff. |
| const unsigned sizeTypeSize = |
| astCtx->getTypeSize(astCtx->getSignedSizeType()); |
| uint64_t endOffset = isCtor ? arrayLen : arrayLen - 1; |
| mlir::Value endOffsetVal = |
| builder.getUnsignedInt(loc, endOffset, sizeTypeSize); |
| |
| auto begin = cir::CastOp::create(builder, loc, eltTy, |
| cir::CastKind::array_to_ptrdecay, arrayAddr); |
| mlir::Value end = |
| cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal); |
| mlir::Value start = isCtor ? begin : end; |
| mlir::Value stop = isCtor ? end : begin; |
| |
| mlir::Value tmpAddr = builder.createAlloca( |
| loc, /*addr type*/ builder.getPointerTo(eltTy), |
| /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1)); |
| builder.createStore(loc, start, tmpAddr); |
| |
| cir::DoWhileOp loop = builder.createDoWhile( |
| loc, |
| /*condBuilder=*/ |
| [&](mlir::OpBuilder &b, mlir::Location loc) { |
| auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr); |
| mlir::Type boolTy = cir::BoolType::get(b.getContext()); |
| auto cmp = builder.create<cir::CmpOp>(loc, boolTy, cir::CmpOpKind::ne, |
| currentElement, stop); |
| builder.createCondition(cmp); |
| }, |
| /*bodyBuilder=*/ |
| [&](mlir::OpBuilder &b, mlir::Location loc) { |
| auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr); |
| |
| cir::CallOp ctorCall; |
| op->walk([&](cir::CallOp c) { ctorCall = c; }); |
| assert(ctorCall && "expected ctor call"); |
| |
| // Array elements get constructed in order but destructed in reverse. |
| mlir::Value stride; |
| if (isCtor) |
| stride = builder.getUnsignedInt(loc, 1, sizeTypeSize); |
| else |
| stride = builder.getSignedInt(loc, -1, sizeTypeSize); |
| |
| ctorCall->moveBefore(stride.getDefiningOp()); |
| ctorCall->setOperand(0, currentElement); |
| auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy, |
| currentElement, stride); |
| |
| // Store the element pointer to the temporary variable |
| builder.createStore(loc, nextElement, tmpAddr); |
| builder.createYield(loc); |
| }); |
| |
| op->replaceAllUsesWith(loop); |
| op->erase(); |
| } |
| |
| void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) { |
| CIRBaseBuilderTy builder(getContext()); |
| builder.setInsertionPointAfter(op.getOperation()); |
| |
| mlir::Type eltTy = op->getRegion(0).getArgument(0).getType(); |
| assert(!cir::MissingFeatures::vlas()); |
| auto arrayLen = |
| mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize(); |
| lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen, |
| false); |
| } |
| |
| void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) { |
| cir::CIRBaseBuilderTy builder(getContext()); |
| builder.setInsertionPointAfter(op.getOperation()); |
| |
| mlir::Type eltTy = op->getRegion(0).getArgument(0).getType(); |
| assert(!cir::MissingFeatures::vlas()); |
| auto arrayLen = |
| mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize(); |
| lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen, |
| true); |
| } |
| |
| void LoweringPreparePass::runOnOp(mlir::Operation *op) { |
| if (auto arrayCtor = dyn_cast<ArrayCtor>(op)) |
| lowerArrayCtor(arrayCtor); |
| else if (auto arrayDtor = dyn_cast<cir::ArrayDtor>(op)) |
| lowerArrayDtor(arrayDtor); |
| else if (auto cast = mlir::dyn_cast<cir::CastOp>(op)) |
| lowerCastOp(cast); |
| else if (auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op)) |
| lowerComplexDivOp(complexDiv); |
| else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op)) |
| lowerComplexMulOp(complexMul); |
| else if (auto glob = mlir::dyn_cast<cir::GlobalOp>(op)) |
| lowerGlobalOp(glob); |
| else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op)) |
| lowerUnaryOp(unary); |
| } |
| |
| void LoweringPreparePass::runOnOperation() { |
| mlir::Operation *op = getOperation(); |
| if (isa<::mlir::ModuleOp>(op)) |
| mlirModule = cast<::mlir::ModuleOp>(op); |
| |
| llvm::SmallVector<mlir::Operation *> opsToTransform; |
| |
| op->walk([&](mlir::Operation *op) { |
| if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp, |
| cir::ComplexMulOp, cir::ComplexDivOp, cir::GlobalOp, |
| cir::UnaryOp>(op)) |
| opsToTransform.push_back(op); |
| }); |
| |
| for (mlir::Operation *o : opsToTransform) |
| runOnOp(o); |
| |
| buildCXXGlobalInitFunc(); |
| buildGlobalCtorDtorList(); |
| } |
| |
| std::unique_ptr<Pass> mlir::createLoweringPreparePass() { |
| return std::make_unique<LoweringPreparePass>(); |
| } |
| |
| std::unique_ptr<Pass> |
| mlir::createLoweringPreparePass(clang::ASTContext *astCtx) { |
| auto pass = std::make_unique<LoweringPreparePass>(); |
| pass->setASTContext(astCtx); |
| return std::move(pass); |
| } |