blob: 9b65f2fc06a379af520a063e09f85bbcdb874c06 [file] [log] [blame]
//===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains types defined by the TestDialect for testing various
// features of MLIR.
//
//===----------------------------------------------------------------------===//
#include "TestTypes.h"
#include "TestDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace test;
// Custom parser for SignednessSemantics. Reads one keyword and maps it
// (case-insensitively) to a signedness value:
//   "u" / "unsigned" -> Unsigned
//   "s" / "signed"   -> Signed
//   "n" / "none"     -> Signless
// Any other keyword emits an error at the keyword's location.
static ParseResult
parseSignedness(AsmParser &parser,
                TestIntegerType::SignednessSemantics &result) {
  auto keywordLoc = parser.getCurrentLocation();
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return failure();

  // True if the keyword matches either the short or the long spelling.
  auto matches = [&](StringRef shortForm, StringRef longForm) {
    return keyword.equals_insensitive(shortForm) ||
           keyword.equals_insensitive(longForm);
  };

  if (matches("u", "unsigned")) {
    result = TestIntegerType::SignednessSemantics::Unsigned;
    return success();
  }
  if (matches("s", "signed")) {
    result = TestIntegerType::SignednessSemantics::Signed;
    return success();
  }
  if (matches("n", "none")) {
    result = TestIntegerType::SignednessSemantics::Signless;
    return success();
  }
  return parser.emitError(keywordLoc, "expected signed, unsigned, or none");
}
// Custom printer for SignednessSemantics. Emits the long keyword spelling:
// "unsigned", "signed", or "none" (for signless).
static void printSignedness(AsmPrinter &printer,
                            const TestIntegerType::SignednessSemantics &ss) {
  StringRef keyword;
  switch (ss) {
  case TestIntegerType::SignednessSemantics::Unsigned:
    keyword = "unsigned";
    break;
  case TestIntegerType::SignednessSemantics::Signed:
    keyword = "signed";
    break;
  case TestIntegerType::SignednessSemantics::Signless:
    keyword = "none";
    break;
  }
  printer << keyword;
}
// These helpers do not need to be declared in the header, but they must be
// declared in the `test` namespace (presumably so that argument-dependent
// lookup finds them for FieldInfo; confirm against TestTypes.h). They are
// declared here and defined immediately below. Separating the declaration
// from the definition follows the LLVM coding standards.
namespace test {
// FieldInfo is used as part of a type parameter, so equality comparison is
// required.
static bool operator==(const FieldInfo &lhs, const FieldInfo &rhs);
// FieldInfo is used as part of a type parameter, so a hash will be computed.
static llvm::hash_code hash_value(const FieldInfo &info); // NOLINT
} // namespace test

// Two FieldInfo values are equal when both their names and their types match.
static bool test::operator==(const FieldInfo &lhs, const FieldInfo &rhs) {
  if (!(lhs.name == rhs.name))
    return false;
  return lhs.type == rhs.type;
}

// Hashes the same fields that operator== compares: the name, then the type.
static llvm::hash_code test::hash_value(const FieldInfo &info) { // NOLINT
  const auto &name = info.name;
  const auto &type = info.type;
  return llvm::hash_combine(name, type);
}
//===----------------------------------------------------------------------===//
// CompoundAType
//===----------------------------------------------------------------------===//
// Parses the body of a CompoundAType:
//   `<` width `,` type `,` `[` (integer (`,` integer)*)? `]` `>`
// An empty integer list (`[]`) is accepted, which is what print() emits when
// the array is empty, so empty arrays round-trip.
Type CompoundAType::parse(AsmParser &parser) {
  int widthOfSomething;
  Type oneType;
  SmallVector<int, 4> arrayOfInts;
  if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
      parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
      parser.parseLSquare())
    return Type();

  // parseOptionalInteger yields an empty OptionalParseResult when the next
  // token is not an integer, so its result must not be dereferenced
  // unconditionally. Instead, check for an immediate `]` (empty list) and
  // otherwise require one or more comma-separated integers. This also rejects
  // a trailing comma.
  if (failed(parser.parseOptionalRSquare())) {
    do {
      int element;
      if (parser.parseInteger(element))
        return Type();
      arrayOfInts.push_back(element);
    } while (succeeded(parser.parseOptionalComma()));
    if (parser.parseRSquare())
      return Type();
  }

  if (parser.parseGreater())
    return Type();
  return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts);
}
// Prints the body of a CompoundAType as
//   `<` width `, ` type `, [` comma-separated ints `]>`
void CompoundAType::print(AsmPrinter &printer) const {
  printer << "<";
  printer << getWidthOfSomething();
  printer << ", ";
  printer << getOneType();
  printer << ", [";
  llvm::interleaveComma(getArrayOfInts(), printer);
  printer << "]>";
}
//===----------------------------------------------------------------------===//
// TestIntegerType
//===----------------------------------------------------------------------===//
// Example type validity checker: a TestIntegerType is valid only when its
// width is at most 8. The signedness is not constrained. Note that no
// diagnostic is emitted through `emitError` on failure.
LogicalResult
TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
                        unsigned width,
                        TestIntegerType::SignednessSemantics ss) {
  const bool widthIsValid = width <= 8;
  return success(widthIsValid);
}
//===----------------------------------------------------------------------===//
// TestType
//===----------------------------------------------------------------------===//
// Emits a remark at `loc` consisting of this type followed by " - TestC".
void TestType::printTypeC(Location loc) const {
  InFlightDiagnostic remark = emitRemark(loc);
  remark << *this << " - TestC";
}
//===----------------------------------------------------------------------===//
// TestTypeWithLayout
//===----------------------------------------------------------------------===//
// Parses `<` key `>`, where key is an unsigned integer.
Type TestTypeWithLayoutType::parse(AsmParser &parser) {
  unsigned key;
  if (failed(parser.parseLess()) || failed(parser.parseInteger(key)) ||
      failed(parser.parseGreater()))
    return Type();
  return TestTypeWithLayoutType::get(parser.getContext(), key);
}
// Prints the key wrapped in angle brackets: `<` key `>`.
void TestTypeWithLayoutType::print(AsmPrinter &printer) const {
  printer << "<";
  printer << getKey();
  printer << ">";
}
// The size in bits is taken from the "size" entry in `params`. extractKind
// returns 1 when no such entry exists.
unsigned
TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout,
                                          DataLayoutEntryListRef params) const {
  const unsigned sizeInBits = extractKind(params, "size");
  return sizeInBits;
}
// The ABI alignment is taken from the "alignment" entry in `params`.
// extractKind returns 1 when no such entry exists.
unsigned
TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout,
                                        DataLayoutEntryListRef params) const {
  const unsigned abiAlignment = extractKind(params, "alignment");
  return abiAlignment;
}
// The preferred alignment is taken from the "preferred" entry in `params`.
// extractKind returns 1 when no such entry exists.
unsigned TestTypeWithLayoutType::getPreferredAlignment(
    const DataLayout &dataLayout, DataLayoutEntryListRef params) const {
  const unsigned preferredAlignment = extractKind(params, "preferred");
  return preferredAlignment;
}
// The new layout is compatible with the old one when the old "alignment" is 1
// (this is also what extractKind returns when the entry is absent), or when
// the new "alignment" does not exceed the old one.
bool TestTypeWithLayoutType::areCompatible(
    DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const {
  const unsigned oldAlignment = extractKind(oldLayout, "alignment");
  if (oldAlignment == 1)
    return true;
  const unsigned newAlignment = extractKind(newLayout, "alignment");
  return newAlignment <= oldAlignment;
}
// Checks that every data layout entry is a two-element array of the form
// [kind, integer], where kind is one of "size", "alignment", or "preferred",
// and that each entry is keyed by a TestTypeWithLayoutType. This is test-only
// code, so well-formedness is asserted rather than diagnosed. The function
// always returns success.
LogicalResult
TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params,
                                      Location loc) const {
  for (DataLayoutEntryInterface entry : params) {
    assert(entry.isTypeEntry() && "unexpected identifier entry");
    assert(entry.getKey().get<Type>().isa<TestTypeWithLayoutType>() &&
           "wrong type passed in");

    auto pairAttr = entry.getValue().dyn_cast<ArrayAttr>();
    assert(pairAttr && pairAttr.getValue().size() == 2 &&
           "expected array of two elements");

    auto kindAttr = pairAttr.getValue().front().dyn_cast<StringAttr>();
    // Only referenced inside asserts; silence unused warnings in NDEBUG.
    (void)kindAttr;
    assert(kindAttr &&
           (kindAttr.getValue() == "size" ||
            kindAttr.getValue() == "alignment" ||
            kindAttr.getValue() == "preferred") &&
           "unexpected kind");
    assert(pairAttr.getValue().back().isa<IntegerAttr>());
  }
  return success();
}
// Returns the integer value of the first entry in `params` whose kind string
// equals `expectedKind`, zero-extended and converted to unsigned. Returns 1
// when no entry matches. Each entry's value is cast (not checked) to a
// [StringAttr, IntegerAttr] array.
unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
                                             StringRef expectedKind) const {
  const unsigned kDefaultValue = 1;
  for (DataLayoutEntryInterface entry : params) {
    auto pairAttr = entry.getValue().cast<ArrayAttr>();
    ArrayRef<Attribute> elements = pairAttr.getValue();
    if (elements.front().cast<StringAttr>().getValue() != expectedKind)
      continue;
    auto valueAttr = elements.back().cast<IntegerAttr>();
    return valueAttr.getValue().getZExtValue();
  }
  return kDefaultValue;
}
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.cpp.inc"
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
namespace {
/// External model used to attach LLVM::PointerElementTypeInterface to
/// SimpleAType (see TestDialect::registerTypes).
struct PtrElementModel
    : LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel,
                                                       SimpleAType> {
  // Intentionally empty: no interface methods are overridden here.
};
} // namespace
// Registers the test dialect's types: the hand-written TestRecursiveType plus
// every type in the tablegen-generated GET_TYPEDEF_LIST. It then attaches the
// PtrElementModel external model to SimpleAType.
void TestDialect::registerTypes() {
// The generated list is spliced directly into the template argument list.
addTypes<TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
>();
SimpleAType::attachInterface<PtrElementModel>(*getContext());
}
// Parses a test dialect type. The tablegen-generated parsers are tried first.
// The only hand-written type is `test_rec<name, body>`. Inside its body, a
// recursive type that is already being parsed (i.e. is in `stack`) is
// referenced by name alone, as `test_rec<name>`.
static Type parseTestType(AsmParser &parser, SetVector<Type> &stack) {
  StringRef mnemonic;
  if (failed(parser.parseKeyword(&mnemonic)))
    return Type();

  // Let the generated parsers handle any mnemonic they recognize.
  Type generated;
  if (generatedTypeParser(parser, mnemonic, generated).hasValue())
    return generated;

  if (mnemonic != "test_rec") {
    parser.emitError(parser.getNameLoc()) << "unknown type!";
    return Type();
  }

  StringRef name;
  if (parser.parseLess() || parser.parseKeyword(&name))
    return Type();
  auto rec = TestRecursiveType::get(parser.getContext(), name);

  // Back-reference to an enclosing type: only the name is present.
  if (stack.contains(rec)) {
    if (failed(parser.parseGreater()))
      return Type();
    return rec;
  }

  // First occurrence: `, body >` follows. Keep `rec` on the stack while the
  // body is parsed so that nested references to it are recognized.
  if (failed(parser.parseComma()))
    return Type();
  stack.insert(rec);
  Type body = parseTestType(parser, stack);
  stack.pop_back();

  if (!body)
    return Type();
  if (failed(parser.parseGreater()))
    return Type();
  if (failed(rec.setBody(body)))
    return Type();
  return rec;
}
// Dialect hook for parsing types. Starts with no recursive types in progress.
Type TestDialect::parseType(DialectAsmParser &parser) const {
  SetVector<Type> inProgress;
  return parseTestType(parser, inProgress);
}
// Prints a test dialect type. Types known to the generated printer are
// handled there. Otherwise the type is cast to TestRecursiveType (asserting
// if it is not one) and printed as `test_rec<name, body>`. When the type is
// already being printed higher up (i.e. it is in `stack`), only
// `test_rec<name>` is printed.
static void printTestType(Type type, AsmPrinter &printer,
                          SetVector<Type> &stack) {
  if (succeeded(generatedTypePrinter(type, printer)))
    return;

  auto rec = type.cast<TestRecursiveType>();
  printer << "test_rec<" << rec.getName();

  const bool isBackReference = stack.contains(rec);
  if (!isBackReference) {
    // Keep `rec` on the stack while printing its body so that nested
    // occurrences are printed as back-references.
    printer << ", ";
    stack.insert(rec);
    printTestType(rec.getBody(), printer, stack);
    stack.pop_back();
  }

  printer << ">";
}
// Dialect hook for printing types. Starts with no recursive types in progress.
void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
  SetVector<Type> inProgress;
  printTestType(type, printer, inProgress);
}