//===- LoweringPrepare.cpp - preparation work for LLVM lowering -----------===//
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "PassDetail.h" |
| #include "clang/AST/ASTContext.h" |
| #include "clang/Basic/TargetInfo.h" |
| #include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h" |
| #include "clang/CIR/Dialect/IR/CIRDialect.h" |
| #include "clang/CIR/Dialect/IR/CIROpsEnums.h" |
| #include "clang/CIR/Dialect/Passes.h" |
| #include "clang/CIR/MissingFeatures.h" |
| |
| #include <memory> |
| |
| using namespace mlir; |
| using namespace cir; |
| |
| namespace { |
| struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> { |
| LoweringPreparePass() = default; |
| void runOnOperation() override; |
| |
| void runOnOp(mlir::Operation *op); |
| void lowerCastOp(cir::CastOp op); |
| void lowerComplexDivOp(cir::ComplexDivOp op); |
| void lowerComplexMulOp(cir::ComplexMulOp op); |
| void lowerUnaryOp(cir::UnaryOp op); |
| void lowerArrayDtor(cir::ArrayDtor op); |
| void lowerArrayCtor(cir::ArrayCtor op); |
| |
| cir::FuncOp buildRuntimeFunction( |
| mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, |
| cir::FuncType type, |
| cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage); |
| |
| /// |
| /// AST related |
| /// ----------- |
| |
| clang::ASTContext *astCtx; |
| |
| /// Tracks current module. |
| mlir::ModuleOp mlirModule; |
| |
| void setASTContext(clang::ASTContext *c) { astCtx = c; } |
| }; |
| |
| } // namespace |
| |
| cir::FuncOp LoweringPreparePass::buildRuntimeFunction( |
| mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc, |
| cir::FuncType type, cir::GlobalLinkageKind linkage) { |
| cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom( |
| mlirModule, StringAttr::get(mlirModule->getContext(), name))); |
| if (!f) { |
| f = builder.create<cir::FuncOp>(loc, name, type); |
| f.setLinkageAttr( |
| cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage)); |
| mlir::SymbolTable::setSymbolVisibility( |
| f, mlir::SymbolTable::Visibility::Private); |
| |
| assert(!cir::MissingFeatures::opFuncExtraAttrs()); |
| } |
| return f; |
| } |
| |
| static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, |
| cir::CastOp op) { |
| cir::CIRBaseBuilderTy builder(ctx); |
| builder.setInsertionPoint(op); |
| |
| mlir::Value src = op.getSrc(); |
| mlir::Value imag = builder.getNullValue(src.getType(), op.getLoc()); |
| return builder.createComplexCreate(op.getLoc(), src, imag); |
| } |
| |
| static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx, |
| cir::CastOp op, |
| cir::CastKind elemToBoolKind) { |
| cir::CIRBaseBuilderTy builder(ctx); |
| builder.setInsertionPoint(op); |
| |
| mlir::Value src = op.getSrc(); |
| if (!mlir::isa<cir::BoolType>(op.getType())) |
| return builder.createComplexReal(op.getLoc(), src); |
| |
| // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b |
| mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src); |
| mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src); |
| |
| cir::BoolType boolTy = builder.getBoolTy(); |
| mlir::Value srcRealToBool = |
| builder.createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy); |
| mlir::Value srcImagToBool = |
| builder.createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy); |
| return builder.createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool); |
| } |
| |
| static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx, |
| cir::CastOp op, |
| cir::CastKind scalarCastKind) { |
| CIRBaseBuilderTy builder(ctx); |
| builder.setInsertionPoint(op); |
| |
| mlir::Value src = op.getSrc(); |
| auto dstComplexElemTy = |
| mlir::cast<cir::ComplexType>(op.getType()).getElementType(); |
| |
| mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src); |
| mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src); |
| |
| mlir::Value dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal, |
| dstComplexElemTy); |
| mlir::Value dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag, |
| dstComplexElemTy); |
| return builder.createComplexCreate(op.getLoc(), dstReal, dstImag); |
| } |
| |
| void LoweringPreparePass::lowerCastOp(cir::CastOp op) { |
| mlir::MLIRContext &ctx = getContext(); |
| mlir::Value loweredValue = [&]() -> mlir::Value { |
| switch (op.getKind()) { |
| case cir::CastKind::float_to_complex: |
| case cir::CastKind::int_to_complex: |
| return lowerScalarToComplexCast(ctx, op); |
| case cir::CastKind::float_complex_to_real: |
| case cir::CastKind::int_complex_to_real: |
| return lowerComplexToScalarCast(ctx, op, op.getKind()); |
| case cir::CastKind::float_complex_to_bool: |
| return lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool); |
| case cir::CastKind::int_complex_to_bool: |
| return lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool); |
| case cir::CastKind::float_complex: |
| return lowerComplexToComplexCast(ctx, op, cir::CastKind::floating); |
| case cir::CastKind::float_complex_to_int_complex: |
| return lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int); |
| case cir::CastKind::int_complex: |
| return lowerComplexToComplexCast(ctx, op, cir::CastKind::integral); |
| case cir::CastKind::int_complex_to_float_complex: |
| return lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float); |
| default: |
| return nullptr; |
| } |
| }(); |
| |
| if (loweredValue) { |
| op.replaceAllUsesWith(loweredValue); |
| op.erase(); |
| } |
| } |
| |
| static mlir::Value buildComplexBinOpLibCall( |
| LoweringPreparePass &pass, CIRBaseBuilderTy &builder, |
| llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics), |
| mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, |
| mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) { |
| cir::FPTypeInterface elementTy = |
| mlir::cast<cir::FPTypeInterface>(ty.getElementType()); |
| |
| llvm::StringRef libFuncName = libFuncNameGetter( |
| llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics())); |
| llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy); |
| |
| cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty); |
| |
| // Insert a declaration for the runtime function to be used in Complex |
| // multiplication and division when needed |
| cir::FuncOp libFunc; |
| { |
| mlir::OpBuilder::InsertionGuard ipGuard{builder}; |
| builder.setInsertionPointToStart(pass.mlirModule.getBody()); |
| libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy); |
| } |
| |
| cir::CallOp call = |
| builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag}); |
| return call.getResult(); |
| } |
| |
| static llvm::StringRef |
| getComplexDivLibCallName(llvm::APFloat::Semantics semantics) { |
| switch (semantics) { |
| case llvm::APFloat::S_IEEEhalf: |
| return "__divhc3"; |
| case llvm::APFloat::S_IEEEsingle: |
| return "__divsc3"; |
| case llvm::APFloat::S_IEEEdouble: |
| return "__divdc3"; |
| case llvm::APFloat::S_PPCDoubleDouble: |
| return "__divtc3"; |
| case llvm::APFloat::S_x87DoubleExtended: |
| return "__divxc3"; |
| case llvm::APFloat::S_IEEEquad: |
| return "__divtc3"; |
| default: |
| llvm_unreachable("unsupported floating point type"); |
| } |
| } |
| |
| static mlir::Value |
| buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, |
| mlir::Value lhsReal, mlir::Value lhsImag, |
| mlir::Value rhsReal, mlir::Value rhsImag) { |
| // (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i |
| mlir::Value &a = lhsReal; |
| mlir::Value &b = lhsImag; |
| mlir::Value &c = rhsReal; |
| mlir::Value &d = rhsImag; |
| |
| mlir::Value ac = builder.createBinop(loc, a, cir::BinOpKind::Mul, c); // a*c |
| mlir::Value bd = builder.createBinop(loc, b, cir::BinOpKind::Mul, d); // b*d |
| mlir::Value cc = builder.createBinop(loc, c, cir::BinOpKind::Mul, c); // c*c |
| mlir::Value dd = builder.createBinop(loc, d, cir::BinOpKind::Mul, d); // d*d |
| mlir::Value acbd = |
| builder.createBinop(loc, ac, cir::BinOpKind::Add, bd); // ac+bd |
| mlir::Value ccdd = |
| builder.createBinop(loc, cc, cir::BinOpKind::Add, dd); // cc+dd |
| mlir::Value resultReal = |
| builder.createBinop(loc, acbd, cir::BinOpKind::Div, ccdd); |
| |
| mlir::Value bc = builder.createBinop(loc, b, cir::BinOpKind::Mul, c); // b*c |
| mlir::Value ad = builder.createBinop(loc, a, cir::BinOpKind::Mul, d); // a*d |
| mlir::Value bcad = |
| builder.createBinop(loc, bc, cir::BinOpKind::Sub, ad); // bc-ad |
| mlir::Value resultImag = |
| builder.createBinop(loc, bcad, cir::BinOpKind::Div, ccdd); |
| return builder.createComplexCreate(loc, resultReal, resultImag); |
| } |
| |
| static mlir::Value |
| buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, |
| mlir::Value lhsReal, mlir::Value lhsImag, |
| mlir::Value rhsReal, mlir::Value rhsImag) { |
| // Implements Smith's algorithm for complex division. |
| // SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962). |
| |
| // Let: |
| // - lhs := a+bi |
| // - rhs := c+di |
| // - result := lhs / rhs = e+fi |
| // |
| // The algorithm pseudocode looks like follows: |
| // if fabs(c) >= fabs(d): |
| // r := d / c |
| // tmp := c + r*d |
| // e = (a + b*r) / tmp |
| // f = (b - a*r) / tmp |
| // else: |
| // r := c / d |
| // tmp := d + r*c |
| // e = (a*r + b) / tmp |
| // f = (b*r - a) / tmp |
| |
| mlir::Value &a = lhsReal; |
| mlir::Value &b = lhsImag; |
| mlir::Value &c = rhsReal; |
| mlir::Value &d = rhsImag; |
| |
| auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) { |
| mlir::Value r = builder.createBinop(loc, d, cir::BinOpKind::Div, |
| c); // r := d / c |
| mlir::Value rd = builder.createBinop(loc, r, cir::BinOpKind::Mul, d); // r*d |
| mlir::Value tmp = builder.createBinop(loc, c, cir::BinOpKind::Add, |
| rd); // tmp := c + r*d |
| |
| mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r |
| mlir::Value abr = |
| builder.createBinop(loc, a, cir::BinOpKind::Add, br); // a + b*r |
| mlir::Value e = builder.createBinop(loc, abr, cir::BinOpKind::Div, tmp); |
| |
| mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r |
| mlir::Value bar = |
| builder.createBinop(loc, b, cir::BinOpKind::Sub, ar); // b - a*r |
| mlir::Value f = builder.createBinop(loc, bar, cir::BinOpKind::Div, tmp); |
| |
| mlir::Value result = builder.createComplexCreate(loc, e, f); |
| builder.createYield(loc, result); |
| }; |
| |
| auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) { |
| mlir::Value r = builder.createBinop(loc, c, cir::BinOpKind::Div, |
| d); // r := c / d |
| mlir::Value rc = builder.createBinop(loc, r, cir::BinOpKind::Mul, c); // r*c |
| mlir::Value tmp = builder.createBinop(loc, d, cir::BinOpKind::Add, |
| rc); // tmp := d + r*c |
| |
| mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r |
| mlir::Value arb = |
| builder.createBinop(loc, ar, cir::BinOpKind::Add, b); // a*r + b |
| mlir::Value e = builder.createBinop(loc, arb, cir::BinOpKind::Div, tmp); |
| |
| mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r |
| mlir::Value bra = |
| builder.createBinop(loc, br, cir::BinOpKind::Sub, a); // b*r - a |
| mlir::Value f = builder.createBinop(loc, bra, cir::BinOpKind::Div, tmp); |
| |
| mlir::Value result = builder.createComplexCreate(loc, e, f); |
| builder.createYield(loc, result); |
| }; |
| |
| auto cFabs = builder.create<cir::FAbsOp>(loc, c); |
| auto dFabs = builder.create<cir::FAbsOp>(loc, d); |
| cir::CmpOp cmpResult = |
| builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs); |
| auto ternary = builder.create<cir::TernaryOp>( |
| loc, cmpResult, trueBranchBuilder, falseBranchBuilder); |
| |
| return ternary.getResult(); |
| } |
| |
| static mlir::Type higherPrecisionElementTypeForComplexArithmetic( |
| mlir::MLIRContext &context, clang::ASTContext &cc, |
| CIRBaseBuilderTy &builder, mlir::Type elementType) { |
| |
| auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type { |
| if (mlir::isa<cir::FP16Type>(type)) |
| return cir::SingleType::get(&context); |
| |
| if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type)) |
| return cir::DoubleType::get(&context); |
| |
| if (mlir::isa<cir::DoubleType>(type)) |
| return cir::LongDoubleType::get(&context, type); |
| |
| return type; |
| }; |
| |
| auto getFloatTypeSemantics = |
| [&cc](mlir::Type type) -> const llvm::fltSemantics & { |
| const clang::TargetInfo &info = cc.getTargetInfo(); |
| if (mlir::isa<cir::FP16Type>(type)) |
| return info.getHalfFormat(); |
| |
| if (mlir::isa<cir::BF16Type>(type)) |
| return info.getBFloat16Format(); |
| |
| if (mlir::isa<cir::SingleType>(type)) |
| return info.getFloatFormat(); |
| |
| if (mlir::isa<cir::DoubleType>(type)) |
| return info.getDoubleFormat(); |
| |
| if (mlir::isa<cir::LongDoubleType>(type)) { |
| if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice) |
| llvm_unreachable("NYI Float type semantics with OpenMP"); |
| return info.getLongDoubleFormat(); |
| } |
| |
| if (mlir::isa<cir::FP128Type>(type)) { |
| if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice) |
| llvm_unreachable("NYI Float type semantics with OpenMP"); |
| return info.getFloat128Format(); |
| } |
| |
| assert(false && "Unsupported float type semantics"); |
| }; |
| |
| const mlir::Type higherElementType = getHigherPrecisionFPType(elementType); |
| const llvm::fltSemantics &elementTypeSemantics = |
| getFloatTypeSemantics(elementType); |
| const llvm::fltSemantics &higherElementTypeSemantics = |
| getFloatTypeSemantics(higherElementType); |
| |
| // Check that the promoted type can handle the intermediate values without |
| // overflowing. This can be interpreted as: |
| // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <= |
| // LargerType.LargestFiniteVal. |
| // In terms of exponent it gives this formula: |
| // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal |
| // doubles the exponent of SmallerType.LargestFiniteVal) |
| if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <= |
| llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) { |
| return higherElementType; |
| } |
| |
| // The intermediate values can't be represented in the promoted type |
| // without overflowing. |
| return {}; |
| } |
| |
| static mlir::Value |
| lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, |
| mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal, |
| mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag, |
| mlir::MLIRContext &mlirCx, clang::ASTContext &cc) { |
| cir::ComplexType complexTy = op.getType(); |
| if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType())) { |
| cir::ComplexRangeKind range = op.getRange(); |
| if (range == cir::ComplexRangeKind::Improved) |
| return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag, |
| rhsReal, rhsImag); |
| |
| if (range == cir::ComplexRangeKind::Full) |
| return buildComplexBinOpLibCall(pass, builder, &getComplexDivLibCallName, |
| loc, complexTy, lhsReal, lhsImag, rhsReal, |
| rhsImag); |
| |
| if (range == cir::ComplexRangeKind::Promoted) { |
| mlir::Type originalElementType = complexTy.getElementType(); |
| mlir::Type higherPrecisionElementType = |
| higherPrecisionElementTypeForComplexArithmetic(mlirCx, cc, builder, |
| originalElementType); |
| |
| if (!higherPrecisionElementType) |
| return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag, |
| rhsReal, rhsImag); |
| |
| cir::CastKind floatingCastKind = cir::CastKind::floating; |
| lhsReal = builder.createCast(floatingCastKind, lhsReal, |
| higherPrecisionElementType); |
| lhsImag = builder.createCast(floatingCastKind, lhsImag, |
| higherPrecisionElementType); |
| rhsReal = builder.createCast(floatingCastKind, rhsReal, |
| higherPrecisionElementType); |
| rhsImag = builder.createCast(floatingCastKind, rhsImag, |
| higherPrecisionElementType); |
| |
| mlir::Value algebraicResult = buildAlgebraicComplexDiv( |
| builder, loc, lhsReal, lhsImag, rhsReal, rhsImag); |
| |
| mlir::Value resultReal = builder.createComplexReal(loc, algebraicResult); |
| mlir::Value resultImag = builder.createComplexImag(loc, algebraicResult); |
| |
| mlir::Value finalReal = |
| builder.createCast(floatingCastKind, resultReal, originalElementType); |
| mlir::Value finalImag = |
| builder.createCast(floatingCastKind, resultImag, originalElementType); |
| return builder.createComplexCreate(loc, finalReal, finalImag); |
| } |
| } |
| |
| return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal, |
| rhsImag); |
| } |
| |
| void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) { |
| cir::CIRBaseBuilderTy builder(getContext()); |
| builder.setInsertionPointAfter(op); |
| mlir::Location loc = op.getLoc(); |
| mlir::TypedValue<cir::ComplexType> lhs = op.getLhs(); |
| mlir::TypedValue<cir::ComplexType> rhs = op.getRhs(); |
| mlir::Value lhsReal = builder.createComplexReal(loc, lhs); |
| mlir::Value lhsImag = builder.createComplexImag(loc, lhs); |
| mlir::Value rhsReal = builder.createComplexReal(loc, rhs); |
| mlir::Value rhsImag = builder.createComplexImag(loc, rhs); |
| |
| mlir::Value loweredResult = |
| lowerComplexDiv(*this, builder, loc, op, lhsReal, lhsImag, rhsReal, |
| rhsImag, getContext(), *astCtx); |
| op.replaceAllUsesWith(loweredResult); |
| op.erase(); |
| } |
| |
| static llvm::StringRef |
| getComplexMulLibCallName(llvm::APFloat::Semantics semantics) { |
| switch (semantics) { |
| case llvm::APFloat::S_IEEEhalf: |
| return "__mulhc3"; |
| case llvm::APFloat::S_IEEEsingle: |
| return "__mulsc3"; |
| case llvm::APFloat::S_IEEEdouble: |
| return "__muldc3"; |
| case llvm::APFloat::S_PPCDoubleDouble: |
| return "__multc3"; |
| case llvm::APFloat::S_x87DoubleExtended: |
| return "__mulxc3"; |
| case llvm::APFloat::S_IEEEquad: |
| return "__multc3"; |
| default: |
| llvm_unreachable("unsupported floating point type"); |
| } |
| } |
| |
| static mlir::Value lowerComplexMul(LoweringPreparePass &pass, |
| CIRBaseBuilderTy &builder, |
| mlir::Location loc, cir::ComplexMulOp op, |
| mlir::Value lhsReal, mlir::Value lhsImag, |
| mlir::Value rhsReal, mlir::Value rhsImag) { |
| // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i |
| mlir::Value resultRealLhs = |
| builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal); |
| mlir::Value resultRealRhs = |
| builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag); |
| mlir::Value resultImagLhs = |
| builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag); |
| mlir::Value resultImagRhs = |
| builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal); |
| mlir::Value resultReal = builder.createBinop( |
| loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs); |
| mlir::Value resultImag = builder.createBinop( |
| loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs); |
| mlir::Value algebraicResult = |
| builder.createComplexCreate(loc, resultReal, resultImag); |
| |
| cir::ComplexType complexTy = op.getType(); |
| cir::ComplexRangeKind rangeKind = op.getRange(); |
| if (mlir::isa<cir::IntType>(complexTy.getElementType()) || |
| rangeKind == cir::ComplexRangeKind::Basic || |
| rangeKind == cir::ComplexRangeKind::Improved || |
| rangeKind == cir::ComplexRangeKind::Promoted) |
| return algebraicResult; |
| |
| assert(!cir::MissingFeatures::fastMathFlags()); |
| |
| // Check whether the real part and the imaginary part of the result are both |
| // NaN. If so, emit a library call to compute the multiplication instead. |
| // We check a value against NaN by comparing the value against itself. |
| mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal); |
| mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag); |
| mlir::Value resultRealAndImagAreNaN = |
| builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN); |
| |
| return builder |
| .create<cir::TernaryOp>( |
| loc, resultRealAndImagAreNaN, |
| [&](mlir::OpBuilder &, mlir::Location) { |
| mlir::Value libCallResult = buildComplexBinOpLibCall( |
| pass, builder, &getComplexMulLibCallName, loc, complexTy, |
| lhsReal, lhsImag, rhsReal, rhsImag); |
| builder.createYield(loc, libCallResult); |
| }, |
| [&](mlir::OpBuilder &, mlir::Location) { |
| builder.createYield(loc, algebraicResult); |
| }) |
| .getResult(); |
| } |
| |
| void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) { |
| cir::CIRBaseBuilderTy builder(getContext()); |
| builder.setInsertionPointAfter(op); |
| mlir::Location loc = op.getLoc(); |
| mlir::TypedValue<cir::ComplexType> lhs = op.getLhs(); |
| mlir::TypedValue<cir::ComplexType> rhs = op.getRhs(); |
| mlir::Value lhsReal = builder.createComplexReal(loc, lhs); |
| mlir::Value lhsImag = builder.createComplexImag(loc, lhs); |
| mlir::Value rhsReal = builder.createComplexReal(loc, rhs); |
| mlir::Value rhsImag = builder.createComplexImag(loc, rhs); |
| mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal, |
| lhsImag, rhsReal, rhsImag); |
| op.replaceAllUsesWith(loweredResult); |
| op.erase(); |
| } |
| |
| void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) { |
| mlir::Type ty = op.getType(); |
| if (!mlir::isa<cir::ComplexType>(ty)) |
| return; |
| |
| mlir::Location loc = op.getLoc(); |
| cir::UnaryOpKind opKind = op.getKind(); |
| |
| CIRBaseBuilderTy builder(getContext()); |
| builder.setInsertionPointAfter(op); |
| |
| mlir::Value operand = op.getInput(); |
| mlir::Value operandReal = builder.createComplexReal(loc, operand); |
| mlir::Value operandImag = builder.createComplexImag(loc, operand); |
| |
| mlir::Value resultReal; |
| mlir::Value resultImag; |
| |
| switch (opKind) { |
| case cir::UnaryOpKind::Inc: |
| case cir::UnaryOpKind::Dec: |
| resultReal = builder.createUnaryOp(loc, opKind, operandReal); |
| resultImag = operandImag; |
| break; |
| |
| case cir::UnaryOpKind::Plus: |
| case cir::UnaryOpKind::Minus: |
| resultReal = builder.createUnaryOp(loc, opKind, operandReal); |
| resultImag = builder.createUnaryOp(loc, opKind, operandImag); |
| break; |
| |
| case cir::UnaryOpKind::Not: |
| resultReal = operandReal; |
| resultImag = |
| builder.createUnaryOp(loc, cir::UnaryOpKind::Minus, operandImag); |
| break; |
| } |
| |
| mlir::Value result = builder.createComplexCreate(loc, resultReal, resultImag); |
| op.replaceAllUsesWith(result); |
| op.erase(); |
| } |
| |
| static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, |
| clang::ASTContext *astCtx, |
| mlir::Operation *op, mlir::Type eltTy, |
| mlir::Value arrayAddr, uint64_t arrayLen, |
| bool isCtor) { |
| // Generate loop to call into ctor/dtor for every element. |
| mlir::Location loc = op->getLoc(); |
| |
| // TODO: instead of getting the size from the AST context, create alias for |
| // PtrDiffTy and unify with CIRGen stuff. |
| const unsigned sizeTypeSize = |
| astCtx->getTypeSize(astCtx->getSignedSizeType()); |
| uint64_t endOffset = isCtor ? arrayLen : arrayLen - 1; |
| mlir::Value endOffsetVal = |
| builder.getUnsignedInt(loc, endOffset, sizeTypeSize); |
| |
| auto begin = cir::CastOp::create(builder, loc, eltTy, |
| cir::CastKind::array_to_ptrdecay, arrayAddr); |
| mlir::Value end = |
| cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal); |
| mlir::Value start = isCtor ? begin : end; |
| mlir::Value stop = isCtor ? end : begin; |
| |
| mlir::Value tmpAddr = builder.createAlloca( |
| loc, /*addr type*/ builder.getPointerTo(eltTy), |
| /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1)); |
| builder.createStore(loc, start, tmpAddr); |
| |
| cir::DoWhileOp loop = builder.createDoWhile( |
| loc, |
| /*condBuilder=*/ |
| [&](mlir::OpBuilder &b, mlir::Location loc) { |
| auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr); |
| mlir::Type boolTy = cir::BoolType::get(b.getContext()); |
| auto cmp = builder.create<cir::CmpOp>(loc, boolTy, cir::CmpOpKind::ne, |
| currentElement, stop); |
| builder.createCondition(cmp); |
| }, |
| /*bodyBuilder=*/ |
| [&](mlir::OpBuilder &b, mlir::Location loc) { |
| auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr); |
| |
| cir::CallOp ctorCall; |
| op->walk([&](cir::CallOp c) { ctorCall = c; }); |
| assert(ctorCall && "expected ctor call"); |
| |
| // Array elements get constructed in order but destructed in reverse. |
| mlir::Value stride; |
| if (isCtor) |
| stride = builder.getUnsignedInt(loc, 1, sizeTypeSize); |
| else |
| stride = builder.getSignedInt(loc, -1, sizeTypeSize); |
| |
| ctorCall->moveBefore(stride.getDefiningOp()); |
| ctorCall->setOperand(0, currentElement); |
| auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy, |
| currentElement, stride); |
| |
| // Store the element pointer to the temporary variable |
| builder.createStore(loc, nextElement, tmpAddr); |
| builder.createYield(loc); |
| }); |
| |
| op->replaceAllUsesWith(loop); |
| op->erase(); |
| } |
| |
| void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) { |
| CIRBaseBuilderTy builder(getContext()); |
| builder.setInsertionPointAfter(op.getOperation()); |
| |
| mlir::Type eltTy = op->getRegion(0).getArgument(0).getType(); |
| assert(!cir::MissingFeatures::vlas()); |
| auto arrayLen = |
| mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize(); |
| lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen, |
| false); |
| } |
| |
| void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) { |
| cir::CIRBaseBuilderTy builder(getContext()); |
| builder.setInsertionPointAfter(op.getOperation()); |
| |
| mlir::Type eltTy = op->getRegion(0).getArgument(0).getType(); |
| assert(!cir::MissingFeatures::vlas()); |
| auto arrayLen = |
| mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize(); |
| lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen, |
| true); |
| } |
| |
| void LoweringPreparePass::runOnOp(mlir::Operation *op) { |
| if (auto arrayCtor = dyn_cast<ArrayCtor>(op)) |
| lowerArrayCtor(arrayCtor); |
| else if (auto arrayDtor = dyn_cast<cir::ArrayDtor>(op)) |
| lowerArrayDtor(arrayDtor); |
| else if (auto cast = mlir::dyn_cast<cir::CastOp>(op)) |
| lowerCastOp(cast); |
| else if (auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op)) |
| lowerComplexDivOp(complexDiv); |
| else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op)) |
| lowerComplexMulOp(complexMul); |
| else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op)) |
| lowerUnaryOp(unary); |
| } |
| |
| void LoweringPreparePass::runOnOperation() { |
| mlir::Operation *op = getOperation(); |
| if (isa<::mlir::ModuleOp>(op)) |
| mlirModule = cast<::mlir::ModuleOp>(op); |
| |
| llvm::SmallVector<mlir::Operation *> opsToTransform; |
| |
| op->walk([&](mlir::Operation *op) { |
| if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp, |
| cir::ComplexMulOp, cir::ComplexDivOp, cir::UnaryOp>(op)) |
| opsToTransform.push_back(op); |
| }); |
| |
| for (mlir::Operation *o : opsToTransform) |
| runOnOp(o); |
| } |
| |
| std::unique_ptr<Pass> mlir::createLoweringPreparePass() { |
| return std::make_unique<LoweringPreparePass>(); |
| } |
| |
| std::unique_ptr<Pass> |
| mlir::createLoweringPreparePass(clang::ASTContext *astCtx) { |
| auto pass = std::make_unique<LoweringPreparePass>(); |
| pass->setASTContext(astCtx); |
| return std::move(pass); |
| } |