| //===-- runtime/dot-product.cpp -------------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "terminator.h" |
| #include "tools.h" |
| #include "flang/Runtime/cpp-type.h" |
| #include "flang/Runtime/descriptor.h" |
| #include "flang/Runtime/reduction.h" |
| #include <cinttypes> |
| |
| namespace Fortran::runtime { |
| |
| // Beware: DOT_PRODUCT of COMPLEX data uses the complex conjugate of the first |
| // argument; MATMUL does not. |
| |
| // General accumulator for any type and stride; this is not used for |
| // contiguous numeric vectors. |
| template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
| class Accumulator { |
| public: |
| using Result = AccumulationType<RCAT, RKIND>; |
| Accumulator(const Descriptor &x, const Descriptor &y) : x_{x}, y_{y} {} |
| void AccumulateIndexed(SubscriptValue xAt, SubscriptValue yAt) { |
| if constexpr (RCAT == TypeCategory::Logical) { |
| sum_ = sum_ || |
| (IsLogicalElementTrue(x_, &xAt) && IsLogicalElementTrue(y_, &yAt)); |
| } else { |
| const XT &xElement{*x_.Element<XT>(&xAt)}; |
| const YT &yElement{*y_.Element<YT>(&yAt)}; |
| if constexpr (RCAT == TypeCategory::Complex) { |
| sum_ += std::conj(static_cast<Result>(xElement)) * |
| static_cast<Result>(yElement); |
| } else { |
| sum_ += static_cast<Result>(xElement) * static_cast<Result>(yElement); |
| } |
| } |
| } |
| Result GetResult() const { return sum_; } |
| |
| private: |
| const Descriptor &x_, &y_; |
| Result sum_{}; |
| }; |
| |
| template <TypeCategory RCAT, int RKIND, typename XT, typename YT> |
| static inline CppTypeFor<RCAT, RKIND> DoDotProduct( |
| const Descriptor &x, const Descriptor &y, Terminator &terminator) { |
| using Result = CppTypeFor<RCAT, RKIND>; |
| RUNTIME_CHECK(terminator, x.rank() == 1 && y.rank() == 1); |
| SubscriptValue n{x.GetDimension(0).Extent()}; |
| if (SubscriptValue yN{y.GetDimension(0).Extent()}; yN != n) { |
| terminator.Crash( |
| "DOT_PRODUCT: SIZE(VECTOR_A) is %jd but SIZE(VECTOR_B) is %jd", |
| static_cast<std::intmax_t>(n), static_cast<std::intmax_t>(yN)); |
| } |
| if constexpr (RCAT != TypeCategory::Logical) { |
| if (x.GetDimension(0).ByteStride() == sizeof(XT) && |
| y.GetDimension(0).ByteStride() == sizeof(YT)) { |
| // Contiguous numeric vectors |
| if constexpr (std::is_same_v<XT, YT>) { |
| // Contiguous homogeneous numeric vectors |
| if constexpr (std::is_same_v<XT, float>) { |
| // TODO: call BLAS-1 SDOT or SDSDOT |
| } else if constexpr (std::is_same_v<XT, double>) { |
| // TODO: call BLAS-1 DDOT |
| } else if constexpr (std::is_same_v<XT, std::complex<float>>) { |
| // TODO: call BLAS-1 CDOTC |
| } else if constexpr (std::is_same_v<XT, std::complex<double>>) { |
| // TODO: call BLAS-1 ZDOTC |
| } |
| } |
| XT *xp{x.OffsetElement<XT>(0)}; |
| YT *yp{y.OffsetElement<YT>(0)}; |
| using AccumType = AccumulationType<RCAT, RKIND>; |
| AccumType accum{}; |
| if constexpr (RCAT == TypeCategory::Complex) { |
| for (SubscriptValue j{0}; j < n; ++j) { |
| accum += std::conj(static_cast<AccumType>(*xp++)) * |
| static_cast<AccumType>(*yp++); |
| } |
| } else { |
| for (SubscriptValue j{0}; j < n; ++j) { |
| accum += |
| static_cast<AccumType>(*xp++) * static_cast<AccumType>(*yp++); |
| } |
| } |
| return static_cast<Result>(accum); |
| } |
| } |
| // Non-contiguous, heterogeneous, & LOGICAL cases |
| SubscriptValue xAt{x.GetDimension(0).LowerBound()}; |
| SubscriptValue yAt{y.GetDimension(0).LowerBound()}; |
| Accumulator<RCAT, RKIND, XT, YT> accumulator{x, y}; |
| for (SubscriptValue j{0}; j < n; ++j) { |
| accumulator.AccumulateIndexed(xAt++, yAt++); |
| } |
| return static_cast<Result>(accumulator.GetResult()); |
| } |
| |
| template <TypeCategory RCAT, int RKIND> struct DotProduct { |
| using Result = CppTypeFor<RCAT, RKIND>; |
| template <TypeCategory XCAT, int XKIND> struct DP1 { |
| template <TypeCategory YCAT, int YKIND> struct DP2 { |
| Result operator()(const Descriptor &x, const Descriptor &y, |
| Terminator &terminator) const { |
| if constexpr (constexpr auto resultType{ |
| GetResultType(XCAT, XKIND, YCAT, YKIND)}) { |
| if constexpr (resultType->first == RCAT && |
| (resultType->second <= RKIND || RCAT == TypeCategory::Logical)) { |
| return DoDotProduct<RCAT, RKIND, CppTypeFor<XCAT, XKIND>, |
| CppTypeFor<YCAT, YKIND>>(x, y, terminator); |
| } |
| } |
| terminator.Crash( |
| "DOT_PRODUCT(%d(%d)): bad operand types (%d(%d), %d(%d))", |
| static_cast<int>(RCAT), RKIND, static_cast<int>(XCAT), XKIND, |
| static_cast<int>(YCAT), YKIND); |
| } |
| }; |
| Result operator()(const Descriptor &x, const Descriptor &y, |
| Terminator &terminator, TypeCategory yCat, int yKind) const { |
| return ApplyType<DP2, Result>(yCat, yKind, terminator, x, y, terminator); |
| } |
| }; |
| Result operator()(const Descriptor &x, const Descriptor &y, |
| const char *source, int line) const { |
| Terminator terminator{source, line}; |
| if (RCAT != TypeCategory::Logical && x.type() == y.type()) { |
| // No conversions needed, operands and result have same known type |
| return typename DP1<RCAT, RKIND>::template DP2<RCAT, RKIND>{}( |
| x, y, terminator); |
| } else { |
| auto xCatKind{x.type().GetCategoryAndKind()}; |
| auto yCatKind{y.type().GetCategoryAndKind()}; |
| RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value()); |
| return ApplyType<DP1, Result>(xCatKind->first, xCatKind->second, |
| terminator, x, y, terminator, yCatKind->first, yCatKind->second); |
| } |
| } |
| }; |
| |
| extern "C" { |
| std::int8_t RTNAME(DotProductInteger1)( |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| return DotProduct<TypeCategory::Integer, 1>{}(x, y, source, line); |
| } |
| std::int16_t RTNAME(DotProductInteger2)( |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| return DotProduct<TypeCategory::Integer, 2>{}(x, y, source, line); |
| } |
| std::int32_t RTNAME(DotProductInteger4)( |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| return DotProduct<TypeCategory::Integer, 4>{}(x, y, source, line); |
| } |
| std::int64_t RTNAME(DotProductInteger8)( |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| return DotProduct<TypeCategory::Integer, 8>{}(x, y, source, line); |
| } |
| #ifdef __SIZEOF_INT128__ |
| common::int128_t RTNAME(DotProductInteger16)( |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| return DotProduct<TypeCategory::Integer, 16>{}(x, y, source, line); |
| } |
| #endif |
| |
| // TODO: REAL/COMPLEX(2 & 3) |
| // Intermediate results and operations are at least 64 bits |
| float RTNAME(DotProductReal4)( |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| return DotProduct<TypeCategory::Real, 4>{}(x, y, source, line); |
| } |
| double RTNAME(DotProductReal8)( |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| return DotProduct<TypeCategory::Real, 8>{}(x, y, source, line); |
| } |
| #if LONG_DOUBLE == 80 |
| long double RTNAME(DotProductReal10)( |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| return DotProduct<TypeCategory::Real, 10>{}(x, y, source, line); |
| } |
| #elif LONG_DOUBLE == 128 |
| long double RTNAME(DotProductReal16)( |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| return DotProduct<TypeCategory::Real, 16>{}(x, y, source, line); |
| } |
| #endif |
| |
| void RTNAME(CppDotProductComplex4)(std::complex<float> &result, |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| auto z{DotProduct<TypeCategory::Complex, 4>{}(x, y, source, line)}; |
| result = std::complex<float>{ |
| static_cast<float>(z.real()), static_cast<float>(z.imag())}; |
| } |
| void RTNAME(CppDotProductComplex8)(std::complex<double> &result, |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| result = DotProduct<TypeCategory::Complex, 8>{}(x, y, source, line); |
| } |
| #if LONG_DOUBLE == 80 |
| void RTNAME(CppDotProductComplex10)(std::complex<long double> &result, |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| result = DotProduct<TypeCategory::Complex, 10>{}(x, y, source, line); |
| } |
| #elif LONG_DOUBLE == 128 |
| void RTNAME(CppDotProductComplex16)(std::complex<long double> &result, |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| result = DotProduct<TypeCategory::Complex, 16>{}(x, y, source, line); |
| } |
| #endif |
| |
| bool RTNAME(DotProductLogical)( |
| const Descriptor &x, const Descriptor &y, const char *source, int line) { |
| return DotProduct<TypeCategory::Logical, 1>{}(x, y, source, line); |
| } |
| } // extern "C" |
| } // namespace Fortran::runtime |