blob: 8dd1a054af883a16c189ed8b42bdbd6a8fe9cd40 [file] [log] [blame]
//===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Implementation file for the abstraction of a tensor type, and JSON loading
// utils.
#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/config.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/raw_ostream.h"
#include <array>
#include <cassert>
#include <functional>
#include <numeric>
using namespace llvm;
namespace llvm {
template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
static std::array<std::string, static_cast<size_t>(TensorType::Total)>
StringRef toString(TensorType TT) {
return TensorTypeNames[static_cast<size_t>(TT)];
void TensorSpec::toJSON(json::OStream &OS) const {
OS.object([&]() {
OS.attribute("name", name());
OS.attribute("type", toString(type()));
OS.attribute("port", port());
OS.attributeArray("shape", [&]() {
for (size_t D : shape())
TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
size_t ElementSize, const std::vector<int64_t> &Shape)
: Name(Name), Port(Port), Type(Type), Shape(Shape),
ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
ElementSize(ElementSize) {}
std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
const json::Value &Value) {
auto EmitError =
[&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
std::string S;
llvm::raw_string_ostream OS(S);
OS << Value;
Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
return std::nullopt;
// FIXME: accept a Path as a parameter, and use it for error reporting.
json::Path::Root Root("tensor_spec");
json::ObjectMapper Mapper(Value, Root);
if (!Mapper)
return EmitError("Value is not a dict");
std::string TensorName;
int TensorPort = -1;
std::string TensorType;
std::vector<int64_t> TensorShape;
if (!<std::string>("name", TensorName))
return EmitError("'name' property not present or not a string");
if (!<std::string>("type", TensorType))
return EmitError("'type' property not present or not a string");
if (!<int>("port", TensorPort))
return EmitError("'port' property not present or not an int");
if (!<std::vector<int64_t>>("shape", TensorShape))
return EmitError("'shape' property not present or not an int array");
#define PARSE_TYPE(T, E) \
if (TensorType == #T) \
return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
return std::nullopt;
std::string tensorValueToString(const char *Buffer, const TensorSpec &Spec) {
switch (Spec.type()) {
#define _IMR_DBG_PRINTER(T, N) \
case TensorType::N: { \
const T *TypedBuff = reinterpret_cast<const T *>(Buffer); \
auto R = llvm::make_range(TypedBuff, TypedBuff + Spec.getElementCount()); \
return llvm::join( \
llvm::map_range(R, [](T V) { return std::to_string(V); }), ","); \
case TensorType::Total:
case TensorType::Invalid:
llvm_unreachable("invalid tensor type");
// To appease warnings about not all control paths returning a value.
return "";
} // namespace llvm