Revert "Revert "[Flang][OpenMP] NFC: Minor refactoring of Reduction lowering code" (#73139)" This reverts commit c2b3f16fb595fa88bfd21b455785c59ac6a21ed4.
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index 4b0582a6..0aaf8f1 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp
@@ -724,276 +724,438 @@ TODO(location, "OMPD_target_data MapOperand BoxType"); } -static std::string getReductionName(llvm::StringRef name, mlir::Type ty) { - return (llvm::Twine(name) + - (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + - llvm::Twine(ty.getIntOrFloatBitWidth())) - .str(); -} - -static std::string getReductionName( - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type ty) { - std::string reductionName; - - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - reductionName = "add_reduction"; - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - reductionName = "multiply_reduction"; - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - return "and_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - return "eqv_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - return "or_reduction"; - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: - return "neqv_reduction"; - default: - reductionName = "other_reduction"; - break; +class ReductionProcessor { +public: + enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR }; + static IntrinsicProc + getReductionType(const Fortran::parser::ProcedureDesignator &pd) { + auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>( + getRealName(pd).ToString()) + .Case("max", IntrinsicProc::MAX) + .Case("min", IntrinsicProc::MIN) + .Case("iand", IntrinsicProc::IAND) + .Case("ior", IntrinsicProc::IOR) + .Case("ieor", IntrinsicProc::IEOR) + .Default(std::nullopt); + assert(redType && "Invalid Reduction"); + return *redType; } - return getReductionName(reductionName, ty); -} + static bool supportedIntrinsicProcReduction( + const Fortran::parser::ProcedureDesignator &pd) { + const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)}; + assert(name && "Invalid Reduction Intrinsic."); + auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>( + getRealName(name).ToString()) + .Case("max", IntrinsicProc::MAX) + .Case("min", IntrinsicProc::MIN) + .Case("iand", IntrinsicProc::IAND) + .Case("ior", IntrinsicProc::IOR) + .Case("ieor", IntrinsicProc::IEOR) + .Default(std::nullopt); + if (redType) + return true; + return false; + } -/// This function returns the identity value of the operator \p reductionOpName. -/// For example: -/// 0 + x = x, -/// 1 * x = x -static int getOperationIdentity(llvm::StringRef reductionOpName, - mlir::Location loc) { - if (reductionOpName.contains("add") || reductionOpName.contains("or") || - reductionOpName.contains("neqv")) - return 0; - if (reductionOpName.contains("multiply") || reductionOpName.contains("and") || - reductionOpName.contains("eqv")) - return 1; - TODO(loc, "Reduction of some intrinsic operators is not supported"); -} + static const Fortran::semantics::SourceName + getRealName(const Fortran::parser::Name *name) { + return name->symbol->GetUltimate().name(); + } -static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type, - llvm::StringRef reductionOpName, - fir::FirOpBuilder &builder) { - assert((fir::isa_integer(type) || fir::isa_real(type) || - type.isa<fir::LogicalType>()) && - "only integer, logical and real types are currently supported"); - if (reductionOpName.contains("max")) { - if (auto ty = type.dyn_cast<mlir::FloatType>()) { - const llvm::fltSemantics &sem = ty.getFloatSemantics(); - return builder.createRealConstant( - loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); + static const Fortran::semantics::SourceName + getRealName(const Fortran::parser::ProcedureDesignator &pd) { + const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)}; + assert(name && "Invalid Reduction Intrinsic."); + return getRealName(name); + } + + static std::string getReductionName(llvm::StringRef name, mlir::Type ty) { + return (llvm::Twine(name) + + (ty.isIntOrIndex() ? llvm::Twine("_i_") : llvm::Twine("_f_")) + + llvm::Twine(ty.getIntOrFloatBitWidth())) + .str(); + } + + static std::string getReductionName( + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type ty) { + std::string reductionName; + + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + reductionName = "add_reduction"; + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + reductionName = "multiply_reduction"; + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + return "and_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + return "eqv_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + return "or_reduction"; + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + return "neqv_reduction"; + default: + reductionName = "other_reduction"; + break; } - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, minInt); - } else if (reductionOpName.contains("min")) { - if (auto ty = type.dyn_cast<mlir::FloatType>()) { - const llvm::fltSemantics &sem = ty.getFloatSemantics(); - return builder.createRealConstant( - loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); + + return getReductionName(reductionName, ty); + } + + /// This function returns the identity value of the operator \p + /// reductionOpName. For example: + /// 0 + x = x, + /// 1 * x = x + static int getOperationIdentity( + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Location loc) { + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + return 0; + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + return 1; + default: + TODO(loc, "Reduction of some intrinsic operators is not supported"); } - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, maxInt); - } else if (reductionOpName.contains("ior")) { - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, zeroInt); - } else if (reductionOpName.contains("ieor")) { - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, zeroInt); - } else if (reductionOpName.contains("iand")) { - unsigned bits = type.getIntOrFloatBitWidth(); - int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); - return builder.createIntegerConstant(loc, type, allOnInt); - } else { + } + + static mlir::Value getIntrinsicProcInitValue( + mlir::Location loc, mlir::Type type, + const Fortran::parser::ProcedureDesignator &procDesignator, + fir::FirOpBuilder &builder) { + assert((fir::isa_integer(type) || fir::isa_real(type) || + type.isa<fir::LogicalType>()) && + "only integer, logical and real types are currently supported"); + switch (getReductionType(procDesignator)) { + case IntrinsicProc::MAX: { + if (auto ty = type.dyn_cast<mlir::FloatType>()) { + const llvm::fltSemantics &sem = ty.getFloatSemantics(); + return builder.createRealConstant( + loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); + } + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, minInt); + } + case IntrinsicProc::MIN: { + if (auto ty = type.dyn_cast<mlir::FloatType>()) { + const llvm::fltSemantics &sem = ty.getFloatSemantics(); + return builder.createRealConstant( + loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); + } + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, maxInt); + } + case IntrinsicProc::IOR: { + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, zeroInt); + } + case IntrinsicProc::IEOR: { + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, zeroInt); + } + case IntrinsicProc::IAND: { + unsigned bits = type.getIntOrFloatBitWidth(); + int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); + return builder.createIntegerConstant(loc, type, allOnInt); + } + } + llvm_unreachable("Unknown Reduction Intrinsic"); + } + + static mlir::Value getIntrinsicOpInitValue( + mlir::Location loc, mlir::Type type, + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, + fir::FirOpBuilder &builder) { if (type.isa<mlir::FloatType>()) return builder.create<mlir::arith::ConstantOp>( loc, type, - builder.getFloatAttr( - type, (double)getOperationIdentity(reductionOpName, loc))); + builder.getFloatAttr(type, + (double)getOperationIdentity(intrinsicOp, loc))); if (type.isa<fir::LogicalType>()) { mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( loc, builder.getI1Type(), builder.getIntegerAttr(builder.getI1Type(), - getOperationIdentity(reductionOpName, loc))); + getOperationIdentity(intrinsicOp, loc))); return builder.createConvert(loc, type, intConst); } return builder.create<mlir::arith::ConstantOp>( loc, type, - builder.getIntegerAttr(type, - getOperationIdentity(reductionOpName, loc))); + builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc))); } -} -template <typename FloatOp, typename IntegerOp> -static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, - mlir::Type type, mlir::Location loc, - mlir::Value op1, mlir::Value op2) { - assert(type.isIntOrIndexOrFloat() && - "only integer and float types are currently supported"); - if (type.isIntOrIndex()) - return builder.create<IntegerOp>(loc, op1, op2); - return builder.create<FloatOp>(loc, op1, op2); -} + template <typename FloatOp, typename IntegerOp> + static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, + mlir::Type type, mlir::Location loc, + mlir::Value op1, mlir::Value op2) { + assert(type.isIntOrIndexOrFloat() && + "only integer and float types are currently supported"); + if (type.isIntOrIndex()) + return builder.create<IntegerOp>(loc, op1, op2); + return builder.create<FloatOp>(loc, op1, op2); + } -static mlir::omp::ReductionDeclareOp -createMinimalReductionDecl(fir::FirOpBuilder &builder, - llvm::StringRef reductionOpName, mlir::Type type, - mlir::Location loc) { - mlir::ModuleOp module = builder.getModule(); - mlir::OpBuilder modBuilder(module.getBodyRegion()); + /// Creates an OpenMP reduction declaration and inserts it into the provided + /// symbol table. The declaration has a constant initializer with the neutral + /// value `initValue`, and the reduction combiner carried over from `reduce`. + /// TODO: Generalize this for non-integer types, add atomic region. + static mlir::omp::ReductionDeclareOp createReductionDecl( + fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, + const Fortran::parser::ProcedureDesignator &procDesignator, + mlir::Type type, mlir::Location loc) { + mlir::OpBuilder::InsertionGuard guard(builder); + mlir::ModuleOp module = builder.getModule(); - mlir::omp::ReductionDeclareOp decl = - modBuilder.create<mlir::omp::ReductionDeclareOp>(loc, reductionOpName, - type); - builder.createBlock(&decl.getInitializerRegion(), - decl.getInitializerRegion().end(), {type}, {loc}); - builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); - mlir::Value init = getReductionInitValue(loc, type, reductionOpName, builder); - builder.create<mlir::omp::YieldOp>(loc, init); + auto decl = + module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName); + if (decl) + return decl; - builder.createBlock(&decl.getReductionRegion(), - decl.getReductionRegion().end(), {type, type}, - {loc, loc}); + mlir::OpBuilder modBuilder(module.getBodyRegion()); - return decl; -} + decl = modBuilder.create<mlir::omp::ReductionDeclareOp>( + loc, reductionOpName, type); + builder.createBlock(&decl.getInitializerRegion(), + decl.getInitializerRegion().end(), {type}, {loc}); + builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); + mlir::Value init = + getIntrinsicProcInitValue(loc, type, procDesignator, builder); + builder.create<mlir::omp::YieldOp>(loc, init); -/// Creates an OpenMP reduction declaration and inserts it into the provided -/// symbol table. The declaration has a constant initializer with the neutral -/// value `initValue`, and the reduction combiner carried over from `reduce`. -/// TODO: Generalize this for non-integer types, add atomic region. -static mlir::omp::ReductionDeclareOp -createReductionDecl(fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, - const Fortran::parser::ProcedureDesignator &procDesignator, - mlir::Type type, mlir::Location loc) { - mlir::OpBuilder::InsertionGuard guard(builder); - mlir::ModuleOp module = builder.getModule(); + builder.createBlock(&decl.getReductionRegion(), + decl.getReductionRegion().end(), {type, type}, + {loc, loc}); - auto decl = - module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName); - if (decl) - return decl; + builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); + mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); + mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - decl = createMinimalReductionDecl(builder, reductionOpName, type, loc); - builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); - mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); - mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - - mlir::Value reductionOp; - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(procDesignator)}) { - if (name->source == "max") { + mlir::Value reductionOp; + switch (getReductionType(procDesignator)) { + case IntrinsicProc::MAX: reductionOp = getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>( builder, type, loc, op1, op2); - } else if (name->source == "min") { + break; + case IntrinsicProc::MIN: reductionOp = getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>( builder, type, loc, op1, op2); - } else if (name->source == "ior") { + break; + case IntrinsicProc::IOR: assert((type.isIntOrIndex()) && "only integer is expected"); reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); - } else if (name->source == "ieor") { + break; + case IntrinsicProc::IEOR: assert((type.isIntOrIndex()) && "only integer is expected"); reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); - } else if (name->source == "iand") { + break; + case IntrinsicProc::IAND: assert((type.isIntOrIndex()) && "only integer is expected"); reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); - } else { + break; + } + + builder.create<mlir::omp::YieldOp>(loc, reductionOp); + return decl; + } + + /// Creates an OpenMP reduction declaration and inserts it into the provided + /// symbol table. The declaration has a constant initializer with the neutral + /// value `initValue`, and the reduction combiner carried over from `reduce`. + /// TODO: Generalize this for non-integer types, add atomic region. + static mlir::omp::ReductionDeclareOp createReductionDecl( + fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, + Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, + mlir::Type type, mlir::Location loc) { + mlir::OpBuilder::InsertionGuard guard(builder); + mlir::ModuleOp module = builder.getModule(); + + auto decl = + module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName); + if (decl) + return decl; + + mlir::OpBuilder modBuilder(module.getBodyRegion()); + + decl = modBuilder.create<mlir::omp::ReductionDeclareOp>( + loc, reductionOpName, type); + builder.createBlock(&decl.getInitializerRegion(), + decl.getInitializerRegion().end(), {type}, {loc}); + builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); + mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder); + builder.create<mlir::omp::YieldOp>(loc, init); + + builder.createBlock(&decl.getReductionRegion(), + decl.getReductionRegion().end(), {type, type}, + {loc, loc}); + + builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); + mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); + mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); + + mlir::Value reductionOp; + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + reductionOp = + getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>( + builder, type, loc, op1, op2); + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + reductionOp = + getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>( + builder, type, loc, op1, op2); + break; + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + + mlir::Value andiOp = + builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); + + reductionOp = builder.createConvert(loc, type, andiOp); + break; + } + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + + mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); + + reductionOp = builder.createConvert(loc, type, oriOp); + break; + } + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + + mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( + loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); + + reductionOp = builder.createConvert(loc, type, cmpiOp); + break; + } + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: { + mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); + mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); + + mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( + loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); + + reductionOp = builder.createConvert(loc, type, cmpiOp); + break; + } + default: TODO(loc, "Reduction of some intrinsic operators is not supported"); } - } - builder.create<mlir::omp::YieldOp>(loc, reductionOp); - return decl; -} - -/// Creates an OpenMP reduction declaration and inserts it into the provided -/// symbol table. The declaration has a constant initializer with the neutral -/// value `initValue`, and the reduction combiner carried over from `reduce`. -/// TODO: Generalize this for non-integer types, add atomic region. -static mlir::omp::ReductionDeclareOp createReductionDecl( - fir::FirOpBuilder &builder, llvm::StringRef reductionOpName, - Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp, - mlir::Type type, mlir::Location loc) { - mlir::OpBuilder::InsertionGuard guard(builder); - mlir::ModuleOp module = builder.getModule(); - - auto decl = - module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName); - if (decl) + builder.create<mlir::omp::YieldOp>(loc, reductionOp); return decl; - - decl = createMinimalReductionDecl(builder, reductionOpName, type, loc); - builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); - mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); - mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); - - mlir::Value reductionOp; - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - reductionOp = - getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>( - builder, type, loc, op1, op2); - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - reductionOp = - getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>( - builder, type, loc, op1, op2); - break; - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: { - mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - - mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); - - reductionOp = builder.createConvert(loc, type, andiOp); - break; - } - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: { - mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - - mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); - - reductionOp = builder.createConvert(loc, type, oriOp); - break; - } - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: { - mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - - mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( - loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); - - reductionOp = builder.createConvert(loc, type, cmpiOp); - break; - } - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: { - mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); - mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); - - mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( - loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); - - reductionOp = builder.createConvert(loc, type, cmpiOp); - break; - } - default: - TODO(loc, "Reduction of some intrinsic operators is not supported"); } - builder.create<mlir::omp::YieldOp>(loc, reductionOp); - return decl; -} + /// Creates a reduction declaration and associates it with an OpenMP block + /// directive. + static void addReductionDecl( + mlir::Location currentLocation, + Fortran::lower::AbstractConverter &converter, + const Fortran::parser::OmpReductionClause &reduction, + llvm::SmallVectorImpl<mlir::Value> &reductionVars, + llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) { + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::omp::ReductionDeclareOp decl; + const auto &redOperator{ + std::get<Fortran::parser::OmpReductionOperator>(reduction.t)}; + const auto &objectList{ + std::get<Fortran::parser::OmpObjectList>(reduction.t)}; + if (const auto &redDefinedOp = + std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { + const auto &intrinsicOp{ + std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( + redDefinedOp->u)}; + switch (intrinsicOp) { + case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: + case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: + case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: + case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: + case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: + case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: + break; + + default: + TODO(currentLocation, + "Reduction of some intrinsic operators is not supported"); + break; + } + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + if (const auto *name{ + Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { + if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) + symVal = declOp.getBase(); + mlir::Type redType = + symVal.getType().cast<fir::ReferenceType>().getEleTy(); + reductionVars.push_back(symVal); + if (redType.isa<fir::LogicalType>()) + decl = createReductionDecl( + firOpBuilder, + getReductionName(intrinsicOp, firOpBuilder.getI1Type()), + intrinsicOp, redType, currentLocation); + else if (redType.isIntOrIndexOrFloat()) { + decl = createReductionDecl(firOpBuilder, + getReductionName(intrinsicOp, redType), + intrinsicOp, redType, currentLocation); + } else { + TODO(currentLocation, "Reduction of some types is not supported"); + } + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); + } + } + } + } else if (const auto *reductionIntrinsic = + std::get_if<Fortran::parser::ProcedureDesignator>( + &redOperator.u)) { + if (ReductionProcessor::supportedIntrinsicProcReduction( + *reductionIntrinsic)) { + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + if (const auto *name{ + Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { + if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + mlir::Value symVal = converter.getSymbolAddress(*symbol); + if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) + symVal = declOp.getBase(); + mlir::Type redType = + symVal.getType().cast<fir::ReferenceType>().getEleTy(); + reductionVars.push_back(symVal); + assert(redType.isIntOrIndexOrFloat() && + "Unsupported reduction type"); + decl = createReductionDecl( + firOpBuilder, + getReductionName(getRealName(*reductionIntrinsic).ToString(), + redType), + *reductionIntrinsic, redType, currentLocation); + reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( + firOpBuilder.getContext(), decl.getSymName())); + } + } + } + } + } + } +}; static mlir::omp::ScheduleModifier translateScheduleModifier(const Fortran::parser::OmpScheduleModifierType &m) { @@ -1176,101 +1338,6 @@ ifVal); } -/// Creates a reduction declaration and associates it with an OpenMP block -/// directive. -static void -addReductionDecl(mlir::Location currentLocation, - Fortran::lower::AbstractConverter &converter, - const Fortran::parser::OmpReductionClause &reduction, - llvm::SmallVectorImpl<mlir::Value> &reductionVars, - llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) { - fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - mlir::omp::ReductionDeclareOp decl; - const auto &redOperator{ - std::get<Fortran::parser::OmpReductionOperator>(reduction.t)}; - const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)}; - if (const auto &redDefinedOp = - std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) { - const auto &intrinsicOp{ - std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>( - redDefinedOp->u)}; - switch (intrinsicOp) { - case Fortran::parser::DefinedOperator::IntrinsicOperator::Add: - case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply: - case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: - case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: - case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: - case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: - break; - - default: - TODO(currentLocation, - "Reduction of some intrinsic operators is not supported"); - break; - } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) - symVal = declOp.getBase(); - mlir::Type redType = - symVal.getType().cast<fir::ReferenceType>().getEleTy(); - reductionVars.push_back(symVal); - if (redType.isa<fir::LogicalType>()) - decl = createReductionDecl( - firOpBuilder, - getReductionName(intrinsicOp, firOpBuilder.getI1Type()), - intrinsicOp, redType, currentLocation); - else if (redType.isIntOrIndexOrFloat()) { - decl = createReductionDecl(firOpBuilder, - getReductionName(intrinsicOp, redType), - intrinsicOp, redType, currentLocation); - } else { - TODO(currentLocation, "Reduction of some types is not supported"); - } - reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } - } - } - } else if (const auto *reductionIntrinsic = - std::get_if<Fortran::parser::ProcedureDesignator>( - &redOperator.u)) { - if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>( - reductionIntrinsic)}) { - if ((name->source != "max") && (name->source != "min") && - (name->source != "ior") && (name->source != "ieor") && - (name->source != "iand")) { - TODO(currentLocation, - "Reduction of intrinsic procedures is not supported"); - } - std::string intrinsicOp = name->ToString(); - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{ - Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value symVal = converter.getSymbolAddress(*symbol); - if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) - symVal = declOp.getBase(); - mlir::Type redType = - symVal.getType().cast<fir::ReferenceType>().getEleTy(); - reductionVars.push_back(symVal); - assert(redType.isIntOrIndexOrFloat() && - "Unsupported reduction type"); - decl = createReductionDecl( - firOpBuilder, getReductionName(intrinsicOp, redType), - *reductionIntrinsic, redType, currentLocation); - reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( - firOpBuilder.getContext(), decl.getSymName())); - } - } - } - } - } -} - static void addUseDeviceClause(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OmpObjectList &useDeviceClause, @@ -1864,8 +1931,9 @@ return findRepeatableClause<ClauseTy::Reduction>( [&](const ClauseTy::Reduction *reductionClause, const Fortran::parser::CharBlock &) { - addReductionDecl(currentLocation, converter, reductionClause->v, - reductionVars, reductionDeclSymbols); + ReductionProcessor rp; + rp.addReductionDecl(currentLocation, converter, reductionClause->v, + reductionVars, reductionDeclSymbols); }); } @@ -3891,48 +3959,50 @@ } else if (const auto *reductionIntrinsic = std::get_if<Fortran::parser::ProcedureDesignator>( &redOperator.u)) { - if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>( - reductionIntrinsic)}) { - std::string redName = name->ToString(); - if ((name->source != "max") && (name->source != "min") && - (name->source != "ior") && (name->source != "ieor") && - (name->source != "iand")) { - continue; - } - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - if (const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>( - ompObject)}) { - if (const Fortran::semantics::Symbol * symbol{name->symbol}) { - mlir::Value reductionVal = converter.getSymbolAddress(*symbol); - if (auto declOp = - reductionVal.getDefiningOp<hlfir::DeclareOp>()) - reductionVal = declOp.getBase(); - for (const mlir::OpOperand &reductionValUse : - reductionVal.getUses()) { - if (auto loadOp = mlir::dyn_cast<fir::LoadOp>( - reductionValUse.getOwner())) { - mlir::Value loadVal = loadOp.getRes(); - // Max is lowered as a compare -> select. - // Match the pattern here. - mlir::Operation *reductionOp = - findReductionChain(loadVal, &reductionVal); - if (reductionOp == nullptr) - continue; + if (!ReductionProcessor::supportedIntrinsicProcReduction( + *reductionIntrinsic)) + continue; + ReductionProcessor::IntrinsicProc redIntrinsicProc = + ReductionProcessor::getReductionType(*reductionIntrinsic); + for (const Fortran::parser::OmpObject &ompObject : objectList.v) { + if (const auto *name{ + Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) { + if (const Fortran::semantics::Symbol * symbol{name->symbol}) { + mlir::Value reductionVal = converter.getSymbolAddress(*symbol); + if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>()) + reductionVal = declOp.getBase(); + for (const mlir::OpOperand &reductionValUse : + reductionVal.getUses()) { + if (auto loadOp = mlir::dyn_cast<fir::LoadOp>( + reductionValUse.getOwner())) { + mlir::Value loadVal = loadOp.getRes(); + // Max is lowered as a compare -> select. + // Match the pattern here. + mlir::Operation *reductionOp = + findReductionChain(loadVal, &reductionVal); + if (reductionOp == nullptr) + continue; - if (redName == "max" || redName == "min") { - assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) && - "Selection Op not found in reduction intrinsic"); - mlir::Operation *compareOp = - getCompareFromReductionOp(reductionOp, loadVal); - updateReduction(compareOp, firOpBuilder, loadVal, - reductionVal); - } - if (redName == "ior" || redName == "ieor" || - redName == "iand") { + if (redIntrinsicProc == + ReductionProcessor::IntrinsicProc::MAX || + redIntrinsicProc == + ReductionProcessor::IntrinsicProc::MIN) { + assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) && + "Selection Op not found in reduction intrinsic"); + mlir::Operation *compareOp = + getCompareFromReductionOp(reductionOp, loadVal); + updateReduction(compareOp, firOpBuilder, loadVal, + reductionVal); + } + if (redIntrinsicProc == + ReductionProcessor::IntrinsicProc::IOR || + redIntrinsicProc == + ReductionProcessor::IntrinsicProc::IEOR || + redIntrinsicProc == + ReductionProcessor::IntrinsicProc::IAND) { - updateReduction(reductionOp, firOpBuilder, loadVal, - reductionVal); - } + updateReduction(reductionOp, firOpBuilder, loadVal, + reductionVal); } } }