blob: 1ae839d5f40e7e5f1e27dcc86a90049c40e2e3d1 [file] [log] [blame]
//===- BytecodeImplementation.h - MLIR Bytecode Implementation --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header defines various interfaces and utilities necessary for dialects
// to hook into bytecode serialization.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
#define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Twine.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// DialectBytecodeReader
//===----------------------------------------------------------------------===//
/// This class defines a virtual interface for reading a bytecode stream,
/// providing hooks into the bytecode reader. As such, this class should only be
/// derived and defined by the main bytecode reader, users (i.e. dialects)
/// should generally only interact with this class via the
/// BytecodeDialectInterface below.
class DialectBytecodeReader {
public:
virtual ~DialectBytecodeReader() = default;
/// Emit an error to the reader.
virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
/// Read out a list of elements, invoking the provided callback for each
/// element. The callback function may be in any of the following forms:
/// * LogicalResult(T &)
/// * FailureOr<T>()
template <typename T, typename CallbackFn>
LogicalResult readList(SmallVectorImpl<T> &result, CallbackFn &&callback) {
uint64_t size;
if (failed(readVarInt(size)))
return failure();
result.reserve(size);
for (uint64_t i = 0; i < size; ++i) {
// Check if the callback uses FailureOr, or populates the result by
// reference.
if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) {
T element = {};
if (failed(callback(element)))
return failure();
result.emplace_back(std::move(element));
} else {
FailureOr<T> element = callback();
if (failed(element))
return failure();
result.emplace_back(std::move(*element));
}
}
return success();
}
//===--------------------------------------------------------------------===//
// IR
//===--------------------------------------------------------------------===//
/// Read a reference to the given attribute.
virtual LogicalResult readAttribute(Attribute &result) = 0;
template <typename T>
LogicalResult readAttributes(SmallVectorImpl<T> &attrs) {
return readList(attrs, [this](T &attr) { return readAttribute(attr); });
}
template <typename T>
LogicalResult readAttribute(T &result) {
Attribute baseResult;
if (failed(readAttribute(baseResult)))
return failure();
if ((result = baseResult.dyn_cast<T>()))
return success();
return emitError() << "expected " << llvm::getTypeName<T>()
<< ", but got: " << baseResult;
}
/// Read a reference to the given type.
virtual LogicalResult readType(Type &result) = 0;
template <typename T>
LogicalResult readTypes(SmallVectorImpl<T> &types) {
return readList(types, [this](T &type) { return readType(type); });
}
template <typename T>
LogicalResult readType(T &result) {
Type baseResult;
if (failed(readType(baseResult)))
return failure();
if ((result = baseResult.dyn_cast<T>()))
return success();
return emitError() << "expected " << llvm::getTypeName<T>()
<< ", but got: " << baseResult;
}
/// Read a handle to a dialect resource.
template <typename ResourceT>
FailureOr<ResourceT> readResourceHandle() {
FailureOr<AsmDialectResourceHandle> handle = readResourceHandle();
if (failed(handle))
return failure();
if (auto *result = dyn_cast<ResourceT>(&*handle))
return std::move(*result);
return emitError() << "provided resource handle differs from the "
"expected resource type";
}
//===--------------------------------------------------------------------===//
// Primitives
//===--------------------------------------------------------------------===//
/// Read a variable width integer.
virtual LogicalResult readVarInt(uint64_t &result) = 0;
/// Read a signed variable width integer.
virtual LogicalResult readSignedVarInt(int64_t &result) = 0;
LogicalResult readSignedVarInts(SmallVectorImpl<int64_t> &result) {
return readList(result,
[this](int64_t &value) { return readSignedVarInt(value); });
}
/// Read an APInt that is known to have been encoded with the given width.
virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0;
/// Read an APFloat that is known to have been encoded with the given
/// semantics.
virtual FailureOr<APFloat>
readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) = 0;
/// Read a string from the bytecode.
virtual LogicalResult readString(StringRef &result) = 0;
/// Read a blob from the bytecode.
virtual LogicalResult readBlob(ArrayRef<char> &result) = 0;
private:
/// Read a handle to a dialect resource.
virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0;
};
//===----------------------------------------------------------------------===//
// DialectBytecodeWriter
//===----------------------------------------------------------------------===//
/// This class defines a virtual interface for writing to a bytecode stream,
/// providing hooks into the bytecode writer. As such, this class should only be
/// derived and defined by the main bytecode writer, users (i.e. dialects)
/// should generally only interact with this class via the
/// BytecodeDialectInterface below.
class DialectBytecodeWriter {
public:
virtual ~DialectBytecodeWriter() = default;
//===--------------------------------------------------------------------===//
// IR
//===--------------------------------------------------------------------===//
/// Write out a list of elements, invoking the provided callback for each
/// element.
template <typename RangeT, typename CallbackFn>
void writeList(RangeT &&range, CallbackFn &&callback) {
writeVarInt(llvm::size(range));
for (auto &element : range)
callback(element);
}
/// Write a reference to the given attribute.
virtual void writeAttribute(Attribute attr) = 0;
template <typename T>
void writeAttributes(ArrayRef<T> attrs) {
writeList(attrs, [this](T attr) { writeAttribute(attr); });
}
/// Write a reference to the given type.
virtual void writeType(Type type) = 0;
template <typename T>
void writeTypes(ArrayRef<T> types) {
writeList(types, [this](T type) { writeType(type); });
}
/// Write the given handle to a dialect resource.
virtual void
writeResourceHandle(const AsmDialectResourceHandle &resource) = 0;
//===--------------------------------------------------------------------===//
// Primitives
//===--------------------------------------------------------------------===//
/// Write a variable width integer to the output stream. This should be the
/// preferred method for emitting integers whenever possible.
virtual void writeVarInt(uint64_t value) = 0;
/// Write a signed variable width integer to the output stream. This should be
/// the preferred method for emitting signed integers whenever possible.
virtual void writeSignedVarInt(int64_t value) = 0;
void writeSignedVarInts(ArrayRef<int64_t> value) {
writeList(value, [this](int64_t value) { writeSignedVarInt(value); });
}
/// Write an APInt to the bytecode stream whose bitwidth will be known
/// externally at read time. This method is useful for encoding APInt values
/// when the width is known via external means, such as via a type. This
/// method should generally only be invoked if you need an APInt, otherwise
/// use the varint methods above. APInt values are generally encoded using
/// zigzag encoding, to enable more efficient encodings for negative values.
virtual void writeAPIntWithKnownWidth(const APInt &value) = 0;
/// Write an APFloat to the bytecode stream whose semantics will be known
/// externally at read time. This method is useful for encoding APFloat values
/// when the semantics are known via external means, such as via a type.
virtual void writeAPFloatWithKnownSemantics(const APFloat &value) = 0;
/// Write a string to the bytecode, which is owned by the caller and is
/// guaranteed to not die before the end of the bytecode process. This should
/// only be called if such a guarantee can be made, such as when the string is
/// owned by an attribute or type.
virtual void writeOwnedString(StringRef str) = 0;
/// Write a blob to the bytecode, which is owned by the caller and is
/// guaranteed to not die before the end of the bytecode process. The blob is
/// written as-is, with no additional compression or compaction.
virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
};
//===----------------------------------------------------------------------===//
// BytecodeDialectInterface
//===----------------------------------------------------------------------===//
class BytecodeDialectInterface
: public DialectInterface::Base<BytecodeDialectInterface> {
public:
using Base::Base;
//===--------------------------------------------------------------------===//
// Reading
//===--------------------------------------------------------------------===//
/// Read an attribute belonging to this dialect from the given reader. This
/// method should return null in the case of failure.
virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
reader.emitError() << "dialect " << getDialect()->getNamespace()
<< " does not support reading attributes from bytecode";
return Attribute();
}
/// Read a type belonging to this dialect from the given reader. This method
/// should return null in the case of failure.
virtual Type readType(DialectBytecodeReader &reader) const {
reader.emitError() << "dialect " << getDialect()->getNamespace()
<< " does not support reading types from bytecode";
return Type();
}
//===--------------------------------------------------------------------===//
// Writing
//===--------------------------------------------------------------------===//
/// Write the given attribute, which belongs to this dialect, to the given
/// writer. This method may return failure to indicate that the given
/// attribute could not be encoded, in which case the textual format will be
/// used to encode this attribute instead.
virtual LogicalResult writeAttribute(Attribute attr,
DialectBytecodeWriter &writer) const {
return failure();
}
/// Write the given type, which belongs to this dialect, to the given writer.
/// This method may return failure to indicate that the given type could not
/// be encoded, in which case the textual format will be used to encode this
/// type instead.
virtual LogicalResult writeType(Type type,
DialectBytecodeWriter &writer) const {
return failure();
}
};
} // namespace mlir
#endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H