blob: 9fe1635f476024b031c442946dd2adbf749adeb4 [file] [log] [blame]
//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the dialect for the Toy IR: custom type parsing and
// operation verification.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Transforms/InliningUtils.h"
using namespace mlir;
using namespace mlir::toy;
//===----------------------------------------------------------------------===//
// ToyInlinerInterface
//===----------------------------------------------------------------------===//
/// Inliner hooks for the Toy dialect. Every Toy call and every Toy operation
/// is reported as legal to inline; the remaining hooks describe how to rewire
/// a `toy.return` terminator and how to bridge type mismatches at call sites.
struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// Call-site legality query: any call within toy may be inlined, so the
  /// arguments are not inspected.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// Operation legality query: any operation within toy may be inlined, so
  /// the arguments are not inspected.
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

  //===--------------------------------------------------------------------===//
  // Transformation Hooks
  //===--------------------------------------------------------------------===//

  /// Handle an inlined terminator. `op` must be a `toy.return`; each entry of
  /// `valuesToRepl` has all of its uses replaced by the return operand at the
  /// same position.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    // "toy.return" is the only terminator that can show up here.
    auto returnOp = cast<ReturnOp>(op);

    // The caller must hand us exactly one value per returned operand.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (unsigned idx = 0, numOperands = returnOp.getNumOperands();
         idx != numOperands; ++idx)
      valuesToRepl[idx].replaceAllUsesWith(returnOp.getOperand(idx));
  }

  /// Materialize a conversion for a type mismatch between a call from this
  /// dialect and a callable region. A `CastOp` is created at `conversionLoc`
  /// that takes `input` as its only operand and produces a single result of
  /// type `resultType`; the created operation is returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    auto castOp = builder.create<CastOp>(conversionLoc, resultType, input);
    return castOp.getOperation();
  }
};
//===----------------------------------------------------------------------===//
// ToyDialect
//===----------------------------------------------------------------------===//
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
///
/// The constructor body registers, in this order: the operations listed by
/// the included `toy/Ops.cpp.inc`, the `ToyInlinerInterface`, and the
/// `StructType` custom type.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
: mlir::Dialect(getDialectNamespace(), ctx, TypeID::get<ToyDialect>()) {
// The include below is expanded inside the template argument list of
// `addOperations`. `GET_OP_LIST` presumably selects the comma-separated op
// list section of the generated file (the generated file is not visible here
// -- confirm), so the define must stay directly above the include.
addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
>();
// Hook the Toy inlining rules into the dialect.
addInterfaces<ToyInlinerInterface>();
// Register the dialect's custom struct type.
addTypes<StructType>();
}
/// Materialize a single constant operation of type `type` from the attribute
/// `value` at location `loc`.
///
/// A `StructType` result is built as a `StructConstantOp` (the value is cast
/// to an `ArrayAttr`); every other type is built as a `ConstantOp` (the value
/// is cast to a `DenseElementsAttr`). Both casts are unchecked, so the
/// attribute kind must already match the requested type.
mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
                                                 mlir::Attribute value,
                                                 mlir::Type type,
                                                 mlir::Location loc) {
  const bool wantsStruct = type.isa<StructType>();
  if (!wantsStruct) {
    // Non-struct constants are backed by dense element data.
    return builder.create<ConstantOp>(loc, type,
                                      value.cast<mlir::DenseElementsAttr>());
  }

  // Struct constants are backed by an array of per-element attributes.
  return builder.create<StructConstantOp>(loc, type,
                                          value.cast<mlir::ArrayAttr>());
}
//===----------------------------------------------------------------------===//
// Toy Operations
//===----------------------------------------------------------------------===//
/// Shared parser for binary operations; it accepts both textual forms emitted
/// by 'printBinaryOp'.
///
/// Grammar: two operands, an optional attribute dictionary, then `:` and a
/// type. When that type is a function type, its inputs give the operand types
/// and its results give the result types. Otherwise the single type is used
/// for both operands and for the result.
///
/// Returns failure as soon as any parsing or operand resolution step fails.
static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
                                       mlir::OperationState &result) {
  SmallVector<mlir::OpAsmParser::OperandType, 2> parsedOperands;
  // Remember where the operands start; it is passed along when resolving
  // operands against a function type.
  llvm::SMLoc operandListLoc = parser.getCurrentLocation();
  Type trailingType;

  if (parser.parseOperandList(parsedOperands, /*requiredOperandCount=*/2))
    return mlir::failure();
  if (parser.parseOptionalAttrDict(result.attributes))
    return mlir::failure();
  if (parser.parseColonType(trailingType))
    return mlir::failure();

  auto funcType = trailingType.dyn_cast<FunctionType>();
  if (!funcType) {
    // Short form: one type shared by both operands and the result.
    if (parser.resolveOperands(parsedOperands, trailingType, result.operands))
      return mlir::failure();
    result.addTypes(trailingType);
    return mlir::success();
  }

  // Functional form: the function type spells out operand and result types.
  if (parser.resolveOperands(parsedOperands, funcType.getInputs(),
                             operandListLoc, result.operands))
    return mlir::failure();
  result.addTypes(funcType.getResults());
  return mlir::success();
}
/// Shared printer for binary operations.
///
/// Emits the operation name, its operands and any attributes, followed by
/// ` : ` and a type. If every operand type equals the first result type, only
/// that type is printed; otherwise the full functional type
/// (operand types -> result types) is printed.
static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
  printer << op->getName() << " " << op->getOperands();
  printer.printOptionalAttrDict(op->getAttrs());
  printer << " : ";

  // The first result's type is the reference for the compact form.
  Type firstResultType = *op->result_type_begin();
  const bool anyOperandDiffers =
      llvm::any_of(op->getOperandTypes(), [=](Type operandType) {
        return operandType != firstResultType;
      });

  if (anyOperandDiffers) {
    // Types are not uniform: fall back to the explicit functional type.
    printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
    return;
  }

  // All types agree, so a single type is enough.
  printer << firstResultType;
}
//===----------------------------------------------------------------------===//
// ConstantOp
/// Build a constant operation from a single `double`.
///
/// The value is wrapped in a dense elements attribute whose type is a ranked
/// f64 tensor with an empty shape, and construction is then delegated to the
/// (type, attribute) `build` overload. `builder` provides the f64 type and
/// `state` is the operation state that gets filled in.
void ConstantOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                       double value) {
  auto scalarTensorType = RankedTensorType::get({}, builder.getF64Type());
  ConstantOp::build(builder, state, scalarTensorType,
                    DenseElementsAttr::get(scalarTensorType, value));
}
/// Parse a constant operation: an optional attribute dictionary followed by a
/// dense elements attribute, which is stored under the name "value".
///
/// The methods of 'OpAsmParser' each return a `ParseResult`, a wrapper around
/// `LogicalResult` that converts to boolean `true` on failure and `false` on
/// success, which is what lets each step be tested directly in an `if`. The
/// parsed pieces populate the `mlir::OperationState`, similarly to a `build`
/// method. The result type of the operation is the type of the parsed
/// attribute.
static mlir::ParseResult parseConstantOp(mlir::OpAsmParser &parser,
                                         mlir::OperationState &result) {
  mlir::DenseElementsAttr constantValue;

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseAttribute(constantValue, "value", result.attributes))
    return failure();

  result.addTypes(constantValue.getType());
  return success();
}
/// Print a constant operation as `toy.constant`, followed by its attribute
/// dictionary (with "value" elided) and then the value attribute itself.
///
/// 'OpAsmPrinter' is a stream that accepts strings, attributes, operands,
/// types, etc.
static void print(mlir::OpAsmPrinter &printer, ConstantOp op) {
  printer << "toy.constant ";

  // "value" is printed explicitly below, so keep it out of the dictionary.
  auto allAttrs = op.getAttrs();
  printer.printOptionalAttrDict(allAttrs, /*elidedAttrs=*/{"value"});

  printer << op.value();
}
/// Verify that the attribute `opaqueValue` is a valid initializer for a
/// constant of type `type`. Diagnostics are reported on `op`.
///
/// - Tensor types require a `DenseFPElementsAttr`. If the tensor type is
///   ranked, the attribute's rank and every dimension must match it; an
///   unranked tensor type accepts any shape.
/// - Struct types require an `ArrayAttr` with exactly one entry per struct
///   element; each entry is verified recursively against the corresponding
///   element type.
/// - Any other type is rejected with a diagnostic. A struct can be created
///   programmatically with arbitrary element types, so this case is reachable
///   during recursion and must not be assumed away with an unchecked cast.
///
/// Returns success if the value is valid, otherwise the emitted failure.
static mlir::LogicalResult verifyConstantForType(mlir::Type type,
                                                 mlir::Attribute opaqueValue,
                                                 mlir::Operation *op) {
  if (type.isa<mlir::TensorType>()) {
    // Check that the value is an elements attribute.
    auto attrValue = opaqueValue.dyn_cast<mlir::DenseFPElementsAttr>();
    if (!attrValue)
      return op->emitError("constant of TensorType must be initialized by "
                           "a DenseFPElementsAttr, got ")
             << opaqueValue;

    // If the return type of the constant is not an unranked tensor, the shape
    // must match the shape of the attribute holding the data.
    auto resultType = type.dyn_cast<mlir::RankedTensorType>();
    if (!resultType)
      return success();

    // Check that the rank of the attribute type matches the rank of the
    // constant result type.
    auto attrType = attrValue.getType().cast<mlir::TensorType>();
    if (attrType.getRank() != resultType.getRank()) {
      return op->emitOpError("return type must match the one of the attached "
                             "value attribute: ")
             << attrType.getRank() << " != " << resultType.getRank();
    }

    // Check that each of the dimensions match between the two types.
    for (int dim = 0, dimE = attrType.getRank(); dim < dimE; ++dim) {
      if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
        return op->emitOpError(
                   "return type shape mismatches its attribute at dimension ")
               << dim << ": " << attrType.getShape()[dim]
               << " != " << resultType.getShape()[dim];
      }
    }
    return mlir::success();
  }

  // Everything that is not a tensor has to be a struct. Use a checked cast so
  // that an unexpected type produces a diagnostic rather than an assertion
  // failure (or undefined behavior when assertions are disabled).
  auto resultType = type.dyn_cast<StructType>();
  if (!resultType)
    return op->emitError("constant must be of TensorType or StructType, got ")
           << type;
  llvm::ArrayRef<mlir::Type> resultElementTypes = resultType.getElementTypes();

  // Verify that the initializer is an Array.
  auto attrValue = opaqueValue.dyn_cast<ArrayAttr>();
  if (!attrValue || attrValue.getValue().size() != resultElementTypes.size())
    return op->emitError("constant of StructType must be initialized by an "
                         "ArrayAttr with the same number of elements, got ")
           << opaqueValue;

  // Check that each of the elements are valid.
  llvm::ArrayRef<mlir::Attribute> attrElementValues = attrValue.getValue();
  for (const auto it : llvm::zip(resultElementTypes, attrElementValues))
    if (failed(verifyConstantForType(std::get<0>(it), std::get<1>(it), op)))
      return mlir::failure();
  return mlir::success();
}
/// Verifier for the constant operation; this corresponds to the
/// `::verify(...)` in the op definition. The op's value attribute is checked
/// against the type of its result.
static mlir::LogicalResult verify(ConstantOp op) {
  mlir::Type declaredType = op.getResult().getType();
  return verifyConstantForType(declaredType, op.value(), op);
}
/// Verifier for the struct constant operation: the op's value attribute is
/// checked against the type of its result.
static mlir::LogicalResult verify(StructConstantOp op) {
  mlir::Type declaredType = op.getResult().getType();
  return verifyConstantForType(declaredType, op.value(), op);
}
/// Shape inference hook for ConstantOp (required by the shape inference
/// interface): the result takes the type of the value attribute.
void ConstantOp::inferShapes() {
  auto valueAttr = value();
  getResult().setType(valueAttr.getType());
}
//===----------------------------------------------------------------------===//
// AddOp
/// Build an AddOp over `lhs` and `rhs`. The result type is set to an unranked
/// tensor of f64.
void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                  mlir::Value lhs, mlir::Value rhs) {
  auto unrankedF64 = UnrankedTensorType::get(builder.getF64Type());
  state.addTypes(unrankedF64);
  state.addOperands({lhs, rhs});
}
/// Infer the output shape of the AddOp, this is required by the shape inference
/// interface.
void AddOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
//===----------------------------------------------------------------------===//
// CastOp
/// Infer the output shape of the CastOp, this is required by the shape
/// inference interface.
void CastOp::inferShapes() { getResult().setType(getOperand().getType()); }
/// Returns true if the given input and result types are compatible with this
/// cast operation. This is required by the `CastOpInterface` to verify this
/// operation and provide other additional utilities.
///
/// A cast is compatible when there is exactly one input and one output, both
/// are tensors with the same element type, and -- if both are ranked -- the
/// two types are identical.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;

  // Both sides have to be tensors.
  auto sourceType = inputs.front().dyn_cast<TensorType>();
  auto targetType = outputs.front().dyn_cast<TensorType>();
  if (!sourceType || !targetType)
    return false;

  // The element type may never change across the cast.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Shapes only need to agree when both types actually carry a rank.
  if (sourceType.hasRank() && targetType.hasRank())
    return sourceType == targetType;
  return true;
}
//===----------------------------------------------------------------------===//
// GenericCallOp
/// Build a generic call to the function named `callee` with the given
/// `arguments`. The callee is stored as a symbol reference attribute named
/// "callee".
void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                          StringRef callee, ArrayRef<mlir::Value> arguments) {
  // A generic call always starts out returning an unranked tensor.
  auto unrankedF64 = UnrankedTensorType::get(builder.getF64Type());
  state.addTypes(unrankedF64);
  state.addOperands(arguments);

  auto calleeRef = builder.getSymbolRefAttr(callee);
  state.addAttribute("callee", calleeRef);
}
/// Return the callee of the generic call operation (required by the call
/// interface). The callee is read from the "callee" symbol reference
/// attribute of this operation.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
  Operation *thisOp = getOperation();
  return thisOp->getAttrOfType<SymbolRefAttr>("callee");
}
/// Return the operands passed as arguments to the called function (required
/// by the call interface). These are the op's `inputs`.
Operation::operand_range GenericCallOp::getArgOperands() {
  Operation::operand_range callArguments = inputs();
  return callArguments;
}
//===----------------------------------------------------------------------===//
// MulOp
/// Build a MulOp over `lhs` and `rhs`. The result type is set to an unranked
/// tensor of f64.
void MulOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                  mlir::Value lhs, mlir::Value rhs) {
  auto unrankedF64 = UnrankedTensorType::get(builder.getF64Type());
  state.addTypes(unrankedF64);
  state.addOperands({lhs, rhs});
}
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
//===----------------------------------------------------------------------===//
// ReturnOp
/// Verifier for the return operation.
///
/// Checks that the op has at most one operand, that the operand count equals
/// the number of results of the enclosing function, and -- when an operand is
/// present -- that its type matches the function result type. Types are
/// considered matching if they are equal or if either one is an unranked
/// tensor.
static mlir::LogicalResult verify(ReturnOp op) {
  // The 'HasParent' trait attached to the operation definition guarantees
  // that the parent operation is a function.
  auto enclosingFunc = cast<FuncOp>(op->getParentOp());

  // ReturnOps can only have a single optional operand.
  if (op.getNumOperands() > 1)
    return op.emitOpError() << "expects at most 1 return operand";

  // The operand number and types must match the function signature.
  llvm::ArrayRef<mlir::Type> funcResults =
      enclosingFunc.getType().getResults();
  if (op.getNumOperands() != funcResults.size()) {
    return op.emitOpError()
           << "does not return the same number of values ("
           << op.getNumOperands() << ") as the enclosing function ("
           << funcResults.size() << ")";
  }

  // Nothing is returned, so there is no type left to check.
  if (!op.hasOperand())
    return mlir::success();

  mlir::Type operandType = *op.operand_type_begin();
  mlir::Type declaredType = funcResults.front();

  // An unranked tensor on either side is compatible with anything.
  const bool typesAgree = operandType == declaredType ||
                          operandType.isa<mlir::UnrankedTensorType>() ||
                          declaredType.isa<mlir::UnrankedTensorType>();
  if (!typesAgree) {
    return op.emitError() << "type of return operand (" << operandType
                          << ") doesn't match function result type ("
                          << declaredType << ")";
  }
  return mlir::success();
}
//===----------------------------------------------------------------------===//
// StructAccessOp
/// Build a struct access of element `index` of `input`.
///
/// `input` must have `StructType` (unchecked cast) and `index` must be within
/// its element count (asserted). The result type is the element type at that
/// index; the index is stored as a 64-bit integer attribute.
void StructAccessOp::build(mlir::OpBuilder &b, mlir::OperationState &state,
                           mlir::Value input, size_t index) {
  // The result type is dictated by the struct being accessed.
  auto aggregateType = input.getType().cast<StructType>();
  assert(index < aggregateType.getNumElementTypes());
  mlir::Type accessedType = aggregateType.getElementTypes()[index];

  // Defer to the auto-generated build method.
  build(b, state, accessedType, input, b.getI64IntegerAttr(index));
}
/// Verifier for the struct access operation: the index must address an
/// existing element of the input struct type, and the op's result type must
/// equal the type of that element.
static mlir::LogicalResult verify(StructAccessOp op) {
  auto aggregateType = op.input().getType().cast<StructType>();
  size_t elementIdx = op.index();

  const bool indexInRange = elementIdx < aggregateType.getNumElementTypes();
  if (!indexInRange)
    return op.emitOpError()
           << "index should be within the range of the input struct type";

  mlir::Type expectedType = aggregateType.getElementTypes()[elementIdx];
  if (op.getResult().getType() == expectedType)
    return mlir::success();

  return op.emitOpError() << "must have the same result type as the struct "
                             "element referred to by the index";
}
//===----------------------------------------------------------------------===//
// TransposeOp
/// Build a TransposeOp over `value`. The result type is set to an unranked
/// tensor of f64.
void TransposeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                        mlir::Value value) {
  auto unrankedF64 = UnrankedTensorType::get(builder.getF64Type());
  state.addTypes(unrankedF64);
  state.addOperands(value);
}
void TransposeOp::inferShapes() {
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
/// Verifier for the transpose operation.
///
/// If either the operand or the result is not a ranked tensor there is
/// nothing to check and verification succeeds. Otherwise the result shape
/// must be the operand shape reversed, which in particular requires both to
/// have the same rank.
static mlir::LogicalResult verify(TransposeOp op) {
  auto inputType = op.getOperand().getType().dyn_cast<RankedTensorType>();
  auto resultType = op.getType().dyn_cast<RankedTensorType>();
  if (!inputType || !resultType)
    return mlir::success();

  llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
  llvm::ArrayRef<int64_t> resultShape = resultType.getShape();

  // The rank comparison must come first: the three-iterator form of
  // std::equal assumes the second range is at least as long as the first, so
  // a lower-rank result would be read past its end, and a higher-rank result
  // would be accepted on a mere prefix match.
  if (inputShape.size() != resultShape.size() ||
      !std::equal(inputShape.begin(), inputShape.end(),
                  resultShape.rbegin())) {
    return op.emitError()
           << "expected result shape to be a transpose of the input";
  }
  return mlir::success();
}
//===----------------------------------------------------------------------===//
// Toy Types
//===----------------------------------------------------------------------===//
namespace mlir {
namespace toy {
namespace detail {
/// Internal storage of the Toy `StructType`: it holds the list of element
/// types that make up the struct.
struct StructTypeStorage : public mlir::TypeStorage {
  /// `KeyTy` is a required type that provides an interface for the storage
  /// instance; it is used when uniquing instances of the storage. A struct
  /// type is uniqued structurally on the element types it contains.
  using KeyTy = llvm::ArrayRef<mlir::Type>;

  /// Construct a storage instance that refers to `elementTypes`. The array is
  /// not copied here.
  StructTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
      : elementTypes(elementTypes) {}

  /// Compare a key against this storage instance. This is used when
  /// constructing a new instance, to make sure an instance for the given key
  /// has not already been uniqued.
  bool operator==(const KeyTy &key) const { return elementTypes == key; }

  /// Hash a key; used when uniquing instances of the storage (see the
  /// `StructType::get` method).
  /// Note: This method isn't necessary as both llvm::ArrayRef and mlir::Type
  /// have hash functions available, so it could be omitted entirely.
  static llvm::hash_code hashKey(const KeyTy &key) {
    llvm::hash_code keyHash = llvm::hash_value(key);
    return keyHash;
  }

  /// Build a key from the parameters that are provided when constructing the
  /// storage instance.
  /// Note: This method isn't necessary because KeyTy can be directly
  /// constructed with the given parameters.
  static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
    KeyTy key(elementTypes);
    return key;
  }

  /// Create a new instance of this storage from an allocator and a `KeyTy`.
  /// The given allocator must be used for *all* dynamic allocations needed to
  /// create the type storage and its internals.
  static StructTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
                                      const KeyTy &key) {
    // First copy the key's elements into the allocator.
    llvm::ArrayRef<mlir::Type> copiedTypes = allocator.copyInto(key);

    // Then allocate the storage itself and construct it in place.
    auto *storageMemory = allocator.allocate<StructTypeStorage>();
    return new (storageMemory) StructTypeStorage(copiedTypes);
  }

  /// The element types of the struct.
  llvm::ArrayRef<mlir::Type> elementTypes;
};
} // end namespace detail
} // end namespace toy
} // end namespace mlir
/// Create an instance of a `StructType` with the given element types. There
/// *must* be at least one element type (asserted), because the context to
/// unique in is taken from the first element type.
StructType StructType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
  assert(!elementTypes.empty() && "expected at least 1 element type");

  // The helper 'get' method in 'TypeBase' returns a uniqued instance of this
  // type. Its first parameter is the context to unique in; the parameters
  // after it are forwarded to the storage instance.
  return Base::get(elementTypes.front().getContext(), elementTypes);
}
/// Returns the element types of this struct type, as recorded in its storage.
llvm::ArrayRef<mlir::Type> StructType::getElementTypes() {
  // 'getImpl' returns a pointer to the internal storage instance.
  auto *storage = getImpl();
  return storage->elementTypes;
}
/// Parse an instance of a type registered to the toy dialect.
///
/// The only accepted form is a struct type:
///   struct-type ::= `struct` `<` type (`,` type)* `>`
/// Each element type must be a TensorType or another StructType. A null type
/// is returned on any parse failure.
///
/// NOTE: All MLIR parser functions return a ParseResult. This is a
/// specialization of LogicalResult that auto-converts to a `true` boolean
/// value on failure, but may also be used with explicit
/// `mlir::failed/mlir::succeeded` as desired.
mlir::Type ToyDialect::parseType(mlir::DialectAsmParser &parser) const {
  // Parse: `struct`
  if (parser.parseKeyword("struct"))
    return Type();
  // Parse: `<`
  if (parser.parseLess())
    return Type();

  // Parse the element types of the struct; at least one is required.
  SmallVector<mlir::Type, 1> elementTypes;
  for (;;) {
    // Remember where this element starts so errors can point at it.
    llvm::SMLoc elementLoc = parser.getCurrentLocation();
    mlir::Type parsedElement;
    if (parser.parseType(parsedElement))
      return Type();

    // Only tensors and nested structs are allowed inside a struct.
    if (!parsedElement.isa<mlir::TensorType, StructType>()) {
      parser.emitError(elementLoc, "element type for a struct must either "
                                   "be a TensorType or a StructType, got: ")
          << parsedElement;
      return Type();
    }
    elementTypes.push_back(parsedElement);

    // Parse the optional `,`; without it the element list is finished.
    if (failed(parser.parseOptionalComma()))
      break;
  }

  // Parse: `>`
  if (parser.parseGreater())
    return Type();
  return StructType::get(elementTypes);
}
/// Print an instance of a type registered to the toy dialect. The output has
/// the form `struct<` element types separated by commas `>`, mirroring what
/// the type parser accepts.
void ToyDialect::printType(mlir::Type type,
                           mlir::DialectAsmPrinter &printer) const {
  // The struct type is currently the only toy type, so the cast is unchecked.
  auto toyStruct = type.cast<StructType>();
  llvm::ArrayRef<mlir::Type> members = toyStruct.getElementTypes();

  printer << "struct<";
  llvm::interleaveComma(members, printer);
  printer << '>';
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "toy/Ops.cpp.inc"