blob: 5e1e1139aba96b3b244767d5b69b374de563cf2c [file] [log] [blame] [edit]
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file contains the definition of the RangeBasedGenerator class, a base
/// class for input generators that operate on a sequence of ranges.
///
//===----------------------------------------------------------------------===//
#ifndef MATHTEST_RANGEBASEDGENERATOR_HPP
#define MATHTEST_RANGEBASEDGENERATOR_HPP
#include "mathtest/IndexedRange.hpp"
#include "mathtest/InputGenerator.hpp"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Parallel.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <tuple>
namespace mathtest {
template <typename Derived, typename... InTypes>
class [[nodiscard]] RangeBasedGenerator : public InputGenerator<InTypes...> {
public:
void reset() noexcept override { NextFlatIndex = 0; }
[[nodiscard]] std::size_t
fill(llvm::MutableArrayRef<InTypes>... Buffers) noexcept override {
const std::array<std::size_t, NumInputs> BufferSizes = {Buffers.size()...};
const std::size_t BufferSize = BufferSizes[0];
assert((BufferSize != 0) && "Buffer size cannot be zero");
assert(std::all_of(BufferSizes.begin(), BufferSizes.end(),
[&](std::size_t Size) { return Size == BufferSize; }) &&
"All input buffers must have the same size");
if (NextFlatIndex >= Size)
return 0;
const auto BatchSize = std::min<uint64_t>(BufferSize, Size - NextFlatIndex);
const auto CurrentFlatIndex = NextFlatIndex;
NextFlatIndex += BatchSize;
auto BufferPtrsTuple = std::make_tuple(Buffers.data()...);
llvm::parallelFor(0, BatchSize, [&](std::size_t Offset) {
static_cast<Derived *>(this)->writeInputs(CurrentFlatIndex, Offset,
BufferPtrsTuple);
});
return static_cast<std::size_t>(BatchSize);
}
protected:
using RangesTupleType = std::tuple<IndexedRange<InTypes>...>;
static constexpr std::size_t NumInputs = sizeof...(InTypes);
static_assert(NumInputs > 0, "The number of inputs must be at least 1");
explicit constexpr RangeBasedGenerator(
const IndexedRange<InTypes> &...Ranges) noexcept
: RangesTuple(Ranges...) {}
explicit constexpr RangeBasedGenerator(
uint64_t Size, const IndexedRange<InTypes> &...Ranges) noexcept
: RangesTuple(Ranges...), Size(Size) {}
RangesTupleType RangesTuple;
uint64_t Size = 0;
private:
uint64_t NextFlatIndex = 0;
};
} // namespace mathtest
#endif // MATHTEST_RANGEBASEDGENERATOR_HPP