blob: c15637d297cd14ae53874e212ba0a26af7cc4f9c [file] [log] [blame]
//===- LoweringPrepare.cpp - preparation work for LLVM lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "clang/AST/ASTContext.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIROpsEnums.h"
#include "clang/CIR/Dialect/Passes.h"
#include "clang/CIR/MissingFeatures.h"
#include <memory>
using namespace mlir;
using namespace cir;
namespace {
/// Pass that rewrites a set of higher-level CIR operations (array ctor/dtor
/// regions, complex casts, complex mul/div and complex unary ops) into simpler
/// CIR operations ahead of LLVM lowering.
struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
  LoweringPreparePass() = default;
  void runOnOperation() override;

  /// Dispatches one collected operation to the matching lowering routine.
  void runOnOp(mlir::Operation *op);
  void lowerCastOp(cir::CastOp op);
  void lowerComplexDivOp(cir::ComplexDivOp op);
  void lowerComplexMulOp(cir::ComplexMulOp op);
  void lowerUnaryOp(cir::UnaryOp op);
  void lowerArrayDtor(cir::ArrayDtor op);
  void lowerArrayCtor(cir::ArrayCtor op);

  /// Returns the cir::FuncOp named `name` visible from the current module,
  /// or creates a private function with the given type and linkage at the
  /// builder's insertion point when no such function exists.
  cir::FuncOp buildRuntimeFunction(
      mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
      cir::FuncType type,
      cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);

  ///
  /// AST related
  /// -----------

  /// AST context used for target-dependent queries (type sizes, floating
  /// point formats). It stays null when the pass is created without one, so
  /// it is initialized explicitly instead of holding an indeterminate value.
  /// Complex division and array ctor/dtor lowering dereference it.
  clang::ASTContext *astCtx = nullptr;

  /// Tracks current module.
  mlir::ModuleOp mlirModule;

  void setASTContext(clang::ASTContext *c) { astCtx = c; }
};
} // namespace
cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
    mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
    cir::FuncType type, cir::GlobalLinkageKind linkage) {
  // Reuse an existing function with this symbol name when there is one.
  mlir::StringAttr symbolName =
      StringAttr::get(mlirModule->getContext(), name);
  mlir::Operation *existing =
      SymbolTable::lookupNearestSymbolFrom(mlirModule, symbolName);
  if (auto found = dyn_cast_or_null<FuncOp>(existing))
    return found;

  // Otherwise create a private function at the builder's insertion point.
  auto created = builder.create<cir::FuncOp>(loc, name, type);
  created.setLinkageAttr(
      cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
  mlir::SymbolTable::setSymbolVisibility(
      created, mlir::SymbolTable::Visibility::Private);
  assert(!cir::MissingFeatures::opFuncExtraAttrs());
  return created;
}
/// Lowers a scalar-to-complex cast: the scalar becomes the real part and the
/// imaginary part is the null value of the scalar's type.
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
                                            cir::CastOp op) {
  cir::CIRBaseBuilderTy builder(ctx);
  builder.setInsertionPoint(op);
  mlir::Location loc = op.getLoc();
  mlir::Value realPart = op.getSrc();
  mlir::Value zeroImag = builder.getNullValue(realPart.getType(), loc);
  return builder.createComplexCreate(loc, realPart, zeroImag);
}
/// Lowers a complex-to-scalar cast.
///
/// For a non-bool destination the result is the real part. For a bool
/// destination the result is `(bool)real || (bool)imag`, where each component
/// is converted with `elemToBoolKind`.
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx,
                                            cir::CastOp op,
                                            cir::CastKind elemToBoolKind) {
  cir::CIRBaseBuilderTy builder(ctx);
  builder.setInsertionPoint(op);
  mlir::Location loc = op.getLoc();
  mlir::Value complexVal = op.getSrc();

  if (!mlir::isa<cir::BoolType>(op.getType()))
    return builder.createComplexReal(loc, complexVal);

  // (bool)(a+bi) => (bool)a || (bool)b
  mlir::Value re = builder.createComplexReal(loc, complexVal);
  mlir::Value im = builder.createComplexImag(loc, complexVal);
  cir::BoolType boolTy = builder.getBoolTy();
  mlir::Value reAsBool = builder.createCast(loc, elemToBoolKind, re, boolTy);
  mlir::Value imAsBool = builder.createCast(loc, elemToBoolKind, im, boolTy);
  return builder.createLogicalOr(loc, reAsBool, imAsBool);
}
/// Lowers a complex-to-complex cast by converting the real and imaginary
/// components separately with `scalarCastKind` and recombining them.
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx,
                                             cir::CastOp op,
                                             cir::CastKind scalarCastKind) {
  CIRBaseBuilderTy builder(ctx);
  builder.setInsertionPoint(op);
  mlir::Location loc = op.getLoc();
  mlir::Value complexVal = op.getSrc();
  auto targetElemTy =
      mlir::cast<cir::ComplexType>(op.getType()).getElementType();

  mlir::Value re = builder.createComplexReal(loc, complexVal);
  mlir::Value im = builder.createComplexImag(loc, complexVal);

  auto convertPart = [&](mlir::Value part) -> mlir::Value {
    return builder.createCast(loc, scalarCastKind, part, targetElemTy);
  };
  mlir::Value newRe = convertPart(re);
  mlir::Value newIm = convertPart(im);
  return builder.createComplexCreate(loc, newRe, newIm);
}
/// Rewrites complex-related casts into operations on the real and imaginary
/// components. Casts of any other kind are left untouched.
void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
  mlir::MLIRContext &ctx = getContext();
  mlir::Value replacement;

  switch (op.getKind()) {
  case cir::CastKind::float_to_complex:
  case cir::CastKind::int_to_complex:
    replacement = lowerScalarToComplexCast(ctx, op);
    break;
  case cir::CastKind::float_complex_to_real:
  case cir::CastKind::int_complex_to_real:
    replacement = lowerComplexToScalarCast(ctx, op, op.getKind());
    break;
  case cir::CastKind::float_complex_to_bool:
    replacement =
        lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool);
    break;
  case cir::CastKind::int_complex_to_bool:
    replacement = lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool);
    break;
  case cir::CastKind::float_complex:
    replacement = lowerComplexToComplexCast(ctx, op, cir::CastKind::floating);
    break;
  case cir::CastKind::float_complex_to_int_complex:
    replacement =
        lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int);
    break;
  case cir::CastKind::int_complex:
    replacement = lowerComplexToComplexCast(ctx, op, cir::CastKind::integral);
    break;
  case cir::CastKind::int_complex_to_float_complex:
    replacement =
        lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float);
    break;
  default:
    // Not a complex-related cast: nothing to rewrite.
    break;
  }

  if (!replacement)
    return;
  op.replaceAllUsesWith(replacement);
  op.erase();
}
/// Emits a call to a runtime routine implementing a complex binary operation.
///
/// The routine name is chosen by `libFuncNameGetter` from the floating-point
/// semantics of the complex element type. The routine takes the four scalar
/// components (lhsReal, lhsImag, rhsReal, rhsImag) and returns `ty`. Its
/// declaration is inserted at the start of the module if not already present.
static mlir::Value buildComplexBinOpLibCall(
    LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
    llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
    mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
    mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
  auto elementTy = mlir::cast<cir::FPTypeInterface>(ty.getElementType());
  llvm::APFloat::Semantics semantics =
      llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics());
  llvm::StringRef libFuncName = libFuncNameGetter(semantics);

  llvm::SmallVector<mlir::Type, 4> argTypes(4, elementTy);
  cir::FuncType libFuncTy = cir::FuncType::get(argTypes, ty);

  // Declare the routine at module scope, then restore the caller's insertion
  // point (the guard is released when the lambda returns).
  cir::FuncOp libFunc = [&]() -> cir::FuncOp {
    mlir::OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(pass.mlirModule.getBody());
    return pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
  }();

  return builder
      .createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag})
      .getResult();
}
/// Returns the name of the runtime routine used for complex division on
/// elements with the given floating-point semantics.
static llvm::StringRef
getComplexDivLibCallName(llvm::APFloat::Semantics semantics) {
  switch (semantics) {
  case llvm::APFloat::S_IEEEhalf:
    return "__divhc3";
  case llvm::APFloat::S_IEEEsingle:
    return "__divsc3";
  case llvm::APFloat::S_IEEEdouble:
    return "__divdc3";
  case llvm::APFloat::S_x87DoubleExtended:
    return "__divxc3";
  // Both 128-bit formats map to the same routine name.
  case llvm::APFloat::S_PPCDoubleDouble:
  case llvm::APFloat::S_IEEEquad:
    return "__divtc3";
  default:
    llvm_unreachable("unsupported floating point type");
  }
}
/// Emits the textbook complex division formula:
///   (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i
/// No scaling is done, so intermediates may overflow or underflow.
static mlir::Value
buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
                         mlir::Value lhsReal, mlir::Value lhsImag,
                         mlir::Value rhsReal, mlir::Value rhsImag) {
  auto emit = [&](mlir::Value x, cir::BinOpKind kind,
                  mlir::Value y) -> mlir::Value {
    return builder.createBinop(loc, x, kind, y);
  };
  const mlir::Value a = lhsReal;
  const mlir::Value b = lhsImag;
  const mlir::Value c = rhsReal;
  const mlir::Value d = rhsImag;

  // Operations are emitted one statement at a time to keep a fixed order.
  mlir::Value ac = emit(a, cir::BinOpKind::Mul, c);
  mlir::Value bd = emit(b, cir::BinOpKind::Mul, d);
  mlir::Value cc = emit(c, cir::BinOpKind::Mul, c);
  mlir::Value dd = emit(d, cir::BinOpKind::Mul, d);
  mlir::Value realNumerator = emit(ac, cir::BinOpKind::Add, bd);
  mlir::Value denominator = emit(cc, cir::BinOpKind::Add, dd);
  mlir::Value realPart =
      emit(realNumerator, cir::BinOpKind::Div, denominator);

  mlir::Value bc = emit(b, cir::BinOpKind::Mul, c);
  mlir::Value ad = emit(a, cir::BinOpKind::Mul, d);
  mlir::Value imagNumerator = emit(bc, cir::BinOpKind::Sub, ad);
  mlir::Value imagPart =
      emit(imagNumerator, cir::BinOpKind::Div, denominator);

  return builder.createComplexCreate(loc, realPart, imagPart);
}
/// Emits complex division using Smith's algorithm.
/// SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962).
///
/// With lhs = a+bi, rhs = c+di and result = e+fi:
///   if fabs(c) >= fabs(d):
///     r = d / c;  tmp = c + r*d;  e = (a + b*r) / tmp;  f = (b - a*r) / tmp
///   else:
///     r = c / d;  tmp = d + r*c;  e = (a*r + b) / tmp;  f = (b*r - a) / tmp
/// The choice is emitted as a cir.ternary on the magnitude comparison.
static mlir::Value
buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
                              mlir::Value lhsReal, mlir::Value lhsImag,
                              mlir::Value rhsReal, mlir::Value rhsImag) {
  const mlir::Value a = lhsReal;
  const mlir::Value b = lhsImag;
  const mlir::Value c = rhsReal;
  const mlir::Value d = rhsImag;

  auto mulOp = [&](mlir::Value x, mlir::Value y) -> mlir::Value {
    return builder.createBinop(loc, x, cir::BinOpKind::Mul, y);
  };
  auto divOp = [&](mlir::Value x, mlir::Value y) -> mlir::Value {
    return builder.createBinop(loc, x, cir::BinOpKind::Div, y);
  };
  auto addOp = [&](mlir::Value x, mlir::Value y) -> mlir::Value {
    return builder.createBinop(loc, x, cir::BinOpKind::Add, y);
  };
  auto subOp = [&](mlir::Value x, mlir::Value y) -> mlir::Value {
    return builder.createBinop(loc, x, cir::BinOpKind::Sub, y);
  };

  // fabs(c) >= fabs(d): scale by d/c.
  auto realDominates = [&](mlir::OpBuilder &, mlir::Location) {
    mlir::Value r = divOp(d, c);
    mlir::Value tmp = addOp(c, mulOp(r, d));
    mlir::Value e = divOp(addOp(a, mulOp(b, r)), tmp);
    mlir::Value f = divOp(subOp(b, mulOp(a, r)), tmp);
    builder.createYield(loc, builder.createComplexCreate(loc, e, f));
  };

  // fabs(c) < fabs(d): scale by c/d.
  auto imagDominates = [&](mlir::OpBuilder &, mlir::Location) {
    mlir::Value r = divOp(c, d);
    mlir::Value tmp = addOp(d, mulOp(r, c));
    mlir::Value e = divOp(addOp(mulOp(a, r), b), tmp);
    mlir::Value f = divOp(subOp(mulOp(b, r), a), tmp);
    builder.createYield(loc, builder.createComplexCreate(loc, e, f));
  };

  auto absC = builder.create<cir::FAbsOp>(loc, c);
  auto absD = builder.create<cir::FAbsOp>(loc, d);
  cir::CmpOp useRealBranch =
      builder.createCompare(loc, cir::CmpOpKind::ge, absC, absD);
  return builder
      .create<cir::TernaryOp>(loc, useRealBranch, realDominates, imagDominates)
      .getResult();
}
/// Picks a wider floating-point element type for the "promoted" complex
/// arithmetic strategy, or returns a null type when no suitable one exists.
///
/// FP16 widens to float, float and bf16 widen to double, and double widens
/// to long double. Any other type maps to itself. The widened type is
/// accepted only if its maximum exponent can hold the doubled exponent of a
/// product of two largest-finite values of the original type, plus one.
///
/// \param context MLIR context used to create the widened types.
/// \param cc AST context providing the target's floating-point formats.
/// \param builder Not used by this function.
/// \param elementType The complex element type to widen.
/// \returns The widened element type, or a null mlir::Type.
static mlir::Type higherPrecisionElementTypeForComplexArithmetic(
    mlir::MLIRContext &context, clang::ASTContext &cc,
    CIRBaseBuilderTy &builder, mlir::Type elementType) {
  auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
    if (mlir::isa<cir::FP16Type>(type))
      return cir::SingleType::get(&context);
    if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
      return cir::DoubleType::get(&context);
    if (mlir::isa<cir::DoubleType>(type))
      return cir::LongDoubleType::get(&context, type);
    return type;
  };

  auto getFloatTypeSemantics =
      [&cc](mlir::Type type) -> const llvm::fltSemantics & {
    const clang::TargetInfo &info = cc.getTargetInfo();
    if (mlir::isa<cir::FP16Type>(type))
      return info.getHalfFormat();
    if (mlir::isa<cir::BF16Type>(type))
      return info.getBFloat16Format();
    if (mlir::isa<cir::SingleType>(type))
      return info.getFloatFormat();
    if (mlir::isa<cir::DoubleType>(type))
      return info.getDoubleFormat();
    if (mlir::isa<cir::LongDoubleType>(type)) {
      if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
        llvm_unreachable("NYI Float type semantics with OpenMP");
      return info.getLongDoubleFormat();
    }
    if (mlir::isa<cir::FP128Type>(type)) {
      if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
        llvm_unreachable("NYI Float type semantics with OpenMP");
      return info.getFloat128Format();
    }
    // A bare assert(false) would fall off the end of a reference-returning
    // lambda in NDEBUG builds (undefined behavior). llvm_unreachable is
    // noreturn, so every path either returns or terminates.
    llvm_unreachable("Unsupported float type semantics");
  };

  const mlir::Type higherElementType = getHigherPrecisionFPType(elementType);
  const llvm::fltSemantics &elementTypeSemantics =
      getFloatTypeSemantics(elementType);
  const llvm::fltSemantics &higherElementTypeSemantics =
      getFloatTypeSemantics(higherElementType);

  // Check that the promoted type can handle the intermediate values without
  // overflowing. This can be interpreted as:
  // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <=
  // LargerType.LargestFiniteVal.
  // In terms of exponent it gives this formula:
  // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal
  // doubles the exponent of SmallerType.LargestFiniteVal)
  if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
      llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
    return higherElementType;
  }

  // The intermediate values can't be represented in the promoted type
  // without overflowing.
  return {};
}
/// Builds the scalar expansion of a complex division according to the op's
/// complex range kind.
///
/// Floating-point elements:
///   - Improved: Smith's range-reduction algorithm.
///   - Full: a call to the runtime division routine.
///   - Promoted: the algebraic formula evaluated in a wider element type,
///     with a fallback to range reduction when no wider type is suitable.
///   - any other range: the algebraic formula.
/// Non-floating-point elements always use the algebraic formula.
static mlir::Value
lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
                mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
                mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
                mlir::MLIRContext &mlirCx, clang::ASTContext &cc) {
  cir::ComplexType complexTy = op.getType();
  mlir::Type elemTy = complexTy.getElementType();

  if (!mlir::isa<cir::FPTypeInterface>(elemTy))
    return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
                                    rhsImag);

  switch (op.getRange()) {
  case cir::ComplexRangeKind::Improved:
    return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
                                         rhsReal, rhsImag);
  case cir::ComplexRangeKind::Full:
    return buildComplexBinOpLibCall(pass, builder, &getComplexDivLibCallName,
                                    loc, complexTy, lhsReal, lhsImag, rhsReal,
                                    rhsImag);
  case cir::ComplexRangeKind::Promoted: {
    mlir::Type widerTy = higherPrecisionElementTypeForComplexArithmetic(
        mlirCx, cc, builder, elemTy);
    if (!widerTy)
      return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
                                           rhsReal, rhsImag);

    auto castTo = [&](mlir::Value v, mlir::Type ty) -> mlir::Value {
      return builder.createCast(cir::CastKind::floating, v, ty);
    };
    mlir::Value wideLhsReal = castTo(lhsReal, widerTy);
    mlir::Value wideLhsImag = castTo(lhsImag, widerTy);
    mlir::Value wideRhsReal = castTo(rhsReal, widerTy);
    mlir::Value wideRhsImag = castTo(rhsImag, widerTy);

    mlir::Value wideQuotient = buildAlgebraicComplexDiv(
        builder, loc, wideLhsReal, wideLhsImag, wideRhsReal, wideRhsImag);
    mlir::Value wideReal = builder.createComplexReal(loc, wideQuotient);
    mlir::Value wideImag = builder.createComplexImag(loc, wideQuotient);

    mlir::Value narrowReal = castTo(wideReal, elemTy);
    mlir::Value narrowImag = castTo(wideImag, elemTy);
    return builder.createComplexCreate(loc, narrowReal, narrowImag);
  }
  default:
    return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
                                    rhsImag);
  }
}
/// Replaces a cir.complex.div with its component-wise expansion, emitted
/// right after the op.
void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
  cir::CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op);
  mlir::Location loc = op.getLoc();

  auto dividend = op.getLhs();
  auto divisor = op.getRhs();
  mlir::Value lhsReal = builder.createComplexReal(loc, dividend);
  mlir::Value lhsImag = builder.createComplexImag(loc, dividend);
  mlir::Value rhsReal = builder.createComplexReal(loc, divisor);
  mlir::Value rhsImag = builder.createComplexImag(loc, divisor);

  mlir::Value quotient =
      lowerComplexDiv(*this, builder, loc, op, lhsReal, lhsImag, rhsReal,
                      rhsImag, getContext(), *astCtx);
  op.replaceAllUsesWith(quotient);
  op.erase();
}
/// Returns the name of the runtime routine used for complex multiplication
/// on elements with the given floating-point semantics.
static llvm::StringRef
getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
  switch (semantics) {
  case llvm::APFloat::S_IEEEhalf:
    return "__mulhc3";
  case llvm::APFloat::S_IEEEsingle:
    return "__mulsc3";
  case llvm::APFloat::S_IEEEdouble:
    return "__muldc3";
  case llvm::APFloat::S_x87DoubleExtended:
    return "__mulxc3";
  // Both 128-bit formats map to the same routine name.
  case llvm::APFloat::S_PPCDoubleDouble:
  case llvm::APFloat::S_IEEEquad:
    return "__multc3";
  default:
    llvm_unreachable("unsupported floating point type");
  }
}
/// Builds the scalar expansion of a complex multiplication:
///   (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
///
/// That algebraic result is returned directly for integer elements and for
/// the Basic, Improved and Promoted range kinds. Otherwise the expansion
/// checks whether both result components are NaN. If they are, it uses a
/// runtime library call instead of the algebraic result.
static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
                                   CIRBaseBuilderTy &builder,
                                   mlir::Location loc, cir::ComplexMulOp op,
                                   mlir::Value lhsReal, mlir::Value lhsImag,
                                   mlir::Value rhsReal, mlir::Value rhsImag) {
  auto emit = [&](mlir::Value x, cir::BinOpKind kind,
                  mlir::Value y) -> mlir::Value {
    return builder.createBinop(loc, x, kind, y);
  };
  mlir::Value ac = emit(lhsReal, cir::BinOpKind::Mul, rhsReal);
  mlir::Value bd = emit(lhsImag, cir::BinOpKind::Mul, rhsImag);
  mlir::Value ad = emit(lhsReal, cir::BinOpKind::Mul, rhsImag);
  mlir::Value bc = emit(lhsImag, cir::BinOpKind::Mul, rhsReal);
  mlir::Value realPart = emit(ac, cir::BinOpKind::Sub, bd);
  mlir::Value imagPart = emit(ad, cir::BinOpKind::Add, bc);
  mlir::Value algebraicResult =
      builder.createComplexCreate(loc, realPart, imagPart);

  cir::ComplexType complexTy = op.getType();
  cir::ComplexRangeKind rangeKind = op.getRange();
  const bool skipNaNRecovery =
      mlir::isa<cir::IntType>(complexTy.getElementType()) ||
      rangeKind == cir::ComplexRangeKind::Basic ||
      rangeKind == cir::ComplexRangeKind::Improved ||
      rangeKind == cir::ComplexRangeKind::Promoted;
  if (skipNaNRecovery)
    return algebraicResult;

  assert(!cir::MissingFeatures::fastMathFlags());

  // When both components are NaN, recompute the product through the library
  // routine. NaN is detected by comparing a value with itself.
  mlir::Value realIsNaN = builder.createIsNaN(loc, realPart);
  mlir::Value imagIsNaN = builder.createIsNaN(loc, imagPart);
  mlir::Value bothNaN = builder.createLogicalAnd(loc, realIsNaN, imagIsNaN);

  auto useLibCall = [&](mlir::OpBuilder &, mlir::Location) {
    mlir::Value libResult = buildComplexBinOpLibCall(
        pass, builder, &getComplexMulLibCallName, loc, complexTy, lhsReal,
        lhsImag, rhsReal, rhsImag);
    builder.createYield(loc, libResult);
  };
  auto keepAlgebraic = [&](mlir::OpBuilder &, mlir::Location) {
    builder.createYield(loc, algebraicResult);
  };
  auto select =
      builder.create<cir::TernaryOp>(loc, bothNaN, useLibCall, keepAlgebraic);
  return select.getResult();
}
/// Replaces a cir.complex.mul with its component-wise expansion, emitted
/// right after the op.
void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
  cir::CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op);
  mlir::Location loc = op.getLoc();

  auto lhsOperand = op.getLhs();
  auto rhsOperand = op.getRhs();
  mlir::Value lhsReal = builder.createComplexReal(loc, lhsOperand);
  mlir::Value lhsImag = builder.createComplexImag(loc, lhsOperand);
  mlir::Value rhsReal = builder.createComplexReal(loc, rhsOperand);
  mlir::Value rhsImag = builder.createComplexImag(loc, rhsOperand);

  mlir::Value product = lowerComplexMul(*this, builder, loc, op, lhsReal,
                                        lhsImag, rhsReal, rhsImag);
  op.replaceAllUsesWith(product);
  op.erase();
}
/// Expands a unary operation on a complex value into operations on its real
/// and imaginary components. Unary ops on non-complex types are left as-is.
void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) {
  if (!mlir::isa<cir::ComplexType>(op.getType()))
    return;

  mlir::Location loc = op.getLoc();
  cir::UnaryOpKind kind = op.getKind();
  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op);

  mlir::Value input = op.getInput();
  mlir::Value re = builder.createComplexReal(loc, input);
  mlir::Value im = builder.createComplexImag(loc, input);

  mlir::Value newRe;
  mlir::Value newIm;
  switch (kind) {
  case cir::UnaryOpKind::Inc:
  case cir::UnaryOpKind::Dec:
    // Increment/decrement touch only the real component.
    newRe = builder.createUnaryOp(loc, kind, re);
    newIm = im;
    break;
  case cir::UnaryOpKind::Plus:
  case cir::UnaryOpKind::Minus:
    // Unary plus/minus apply to both components.
    newRe = builder.createUnaryOp(loc, kind, re);
    newIm = builder.createUnaryOp(loc, kind, im);
    break;
  case cir::UnaryOpKind::Not:
    // Keep the real part and negate the imaginary part (the conjugate).
    newRe = re;
    newIm = builder.createUnaryOp(loc, cir::UnaryOpKind::Minus, im);
    break;
  }

  mlir::Value rebuilt = builder.createComplexCreate(loc, newRe, newIm);
  op.replaceAllUsesWith(rebuilt);
  op.erase();
}
/// Replaces an array ctor/dtor op with a cir.do loop that runs the op's
/// per-element call once for each array element.
///
/// Constructors walk forward: the cursor starts at `begin`, the call is made
/// on the cursor, then the cursor advances, until it reaches `end`.
/// Destructors walk backward in reverse construction order: the cursor starts
/// at `end` and is decremented *before* each call, until it reaches `begin`.
/// So the elements destroyed are end-1, ..., begin. Stepping first guarantees
/// the first element is destroyed and no pointer before `begin` is formed.
/// (Stepping after the call would skip element 0, and for a one-element
/// array would never terminate.)
///
/// NOTE(review): the body of a do-while always runs at least once, so this
/// assumes arrayLen > 0.
///
/// \param builder Builder positioned where the loop is emitted.
/// \param astCtx AST context used to size the offset/stride constants.
/// \param op The ArrayCtor/ArrayDtor op; it is replaced and erased.
/// \param eltTy Type of the op's region argument, used as the element
///        pointer type for the cursor.
/// \param arrayAddr Address of the array.
/// \param arrayLen Number of array elements.
/// \param isCtor True to construct, false to destroy.
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder,
                                       clang::ASTContext *astCtx,
                                       mlir::Operation *op, mlir::Type eltTy,
                                       mlir::Value arrayAddr, uint64_t arrayLen,
                                       bool isCtor) {
  mlir::Location loc = op->getLoc();

  // TODO: instead of getting the size from the AST context, create alias for
  // PtrDiffTy and unify with CIRGen stuff.
  const unsigned sizeTypeSize =
      astCtx->getTypeSize(astCtx->getSignedSizeType());

  // `end` is one past the last element in both directions. It is only
  // compared against (ctor) or used as the starting cursor that is
  // decremented before use (dtor); it is never passed to the call itself.
  mlir::Value endOffsetVal =
      builder.getUnsignedInt(loc, arrayLen, sizeTypeSize);
  auto begin = cir::CastOp::create(builder, loc, eltTy,
                                   cir::CastKind::array_to_ptrdecay, arrayAddr);
  mlir::Value end =
      cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal);
  mlir::Value start = isCtor ? begin : end;
  mlir::Value stop = isCtor ? end : begin;

  mlir::Value tmpAddr = builder.createAlloca(
      loc, /*addr type*/ builder.getPointerTo(eltTy),
      /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1));
  builder.createStore(loc, start, tmpAddr);

  cir::DoWhileOp loop = builder.createDoWhile(
      loc,
      /*condBuilder=*/
      [&](mlir::OpBuilder &b, mlir::Location loc) {
        auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr);
        mlir::Type boolTy = cir::BoolType::get(b.getContext());
        auto cmp = builder.create<cir::CmpOp>(loc, boolTy, cir::CmpOpKind::ne,
                                              currentElement, stop);
        builder.createCondition(cmp);
      },
      /*bodyBuilder=*/
      [&](mlir::OpBuilder &b, mlir::Location loc) {
        auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr);

        cir::CallOp ctorCall;
        op->walk([&](cir::CallOp c) { ctorCall = c; });
        assert(ctorCall && "expected ctor call");

        // Array elements get constructed in order but destructed in reverse.
        mlir::Value stride;
        if (isCtor)
          stride = builder.getUnsignedInt(loc, 1, sizeTypeSize);
        else
          stride = builder.getSignedInt(loc, -1, sizeTypeSize);
        mlir::Value nextElement = cir::PtrStrideOp::create(
            builder, loc, eltTy, currentElement, stride);

        if (isCtor) {
          // Construct the element under the cursor, then advance.
          ctorCall->moveBefore(stride.getDefiningOp());
          ctorCall->setOperand(0, currentElement);
        } else {
          // Step back first, then destroy the element now under the cursor.
          ctorCall->moveAfter(nextElement.getDefiningOp());
          ctorCall->setOperand(0, nextElement);
        }

        // Store the element pointer to the temporary variable
        builder.createStore(loc, nextElement, tmpAddr);
        builder.createYield(loc);
      });

  op->replaceAllUsesWith(loop);
  op->erase();
}
/// Lowers a cir.array.dtor into an explicit destruction loop emitted right
/// after the op.
void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) {
  CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op.getOperation());

  // Type of the region's block argument, used as the element pointer type.
  mlir::Type elementPtrTy = op->getRegion(0).getArgument(0).getType();

  // Only arrays whose length is a compile-time constant are handled.
  assert(!cir::MissingFeatures::vlas());
  auto addr = op.getAddr();
  auto arrayTy = mlir::cast<cir::ArrayType>(addr.getType().getPointee());

  lowerArrayDtorCtorIntoLoop(builder, astCtx, op, elementPtrTy, addr,
                             arrayTy.getSize(), /*isCtor=*/false);
}
/// Lowers a cir.array.ctor into an explicit construction loop emitted right
/// after the op.
void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) {
  cir::CIRBaseBuilderTy builder(getContext());
  builder.setInsertionPointAfter(op.getOperation());

  // Type of the region's block argument, used as the element pointer type.
  mlir::Type elementPtrTy = op->getRegion(0).getArgument(0).getType();

  // Only arrays whose length is a compile-time constant are handled.
  assert(!cir::MissingFeatures::vlas());
  auto addr = op.getAddr();
  auto arrayTy = mlir::cast<cir::ArrayType>(addr.getType().getPointee());

  lowerArrayDtorCtorIntoLoop(builder, astCtx, op, elementPtrTy, addr,
                             arrayTy.getSize(), /*isCtor=*/true);
}
/// Routes a collected operation to the lowering routine for its kind.
void LoweringPreparePass::runOnOp(mlir::Operation *op) {
  if (auto arrayCtor = mlir::dyn_cast<cir::ArrayCtor>(op)) {
    lowerArrayCtor(arrayCtor);
    return;
  }
  if (auto arrayDtor = mlir::dyn_cast<cir::ArrayDtor>(op)) {
    lowerArrayDtor(arrayDtor);
    return;
  }
  if (auto castOp = mlir::dyn_cast<cir::CastOp>(op)) {
    lowerCastOp(castOp);
    return;
  }
  if (auto divOp = mlir::dyn_cast<cir::ComplexDivOp>(op)) {
    lowerComplexDivOp(divOp);
    return;
  }
  if (auto mulOp = mlir::dyn_cast<cir::ComplexMulOp>(op)) {
    lowerComplexMulOp(mulOp);
    return;
  }
  if (auto unaryOp = mlir::dyn_cast<cir::UnaryOp>(op))
    lowerUnaryOp(unaryOp);
}
void LoweringPreparePass::runOnOperation() {
mlir::Operation *op = getOperation();
if (isa<::mlir::ModuleOp>(op))
mlirModule = cast<::mlir::ModuleOp>(op);
llvm::SmallVector<mlir::Operation *> opsToTransform;
op->walk([&](mlir::Operation *op) {
if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
cir::ComplexMulOp, cir::ComplexDivOp, cir::UnaryOp>(op))
opsToTransform.push_back(op);
});
for (mlir::Operation *o : opsToTransform)
runOnOp(o);
}
/// Creates the pass with no AST context attached.
std::unique_ptr<Pass> mlir::createLoweringPreparePass() {
  std::unique_ptr<Pass> result = std::make_unique<LoweringPreparePass>();
  return result;
}
/// Creates the pass and attaches the given AST context to it.
std::unique_ptr<Pass>
mlir::createLoweringPreparePass(clang::ASTContext *astCtx) {
  auto lowering = std::make_unique<LoweringPreparePass>();
  lowering->setASTContext(astCtx);
  std::unique_ptr<Pass> result = std::move(lowering);
  return result;
}