blob: a55f3539f31db004f94ea6c89d673ddf754518ba [file] [log] [blame]
//===- OpDefinitionsGen.cpp - IRDL op definitions generator ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpDefinitionsGen uses the description of operations to generate IRDL
// definitions for ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/GenNameParser.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
using tblgen::NamedTypeConstraint;
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-irdl-dialect");
llvm::cl::opt<std::string>
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::Required);
Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
const Record &predRec = constraint.getDef();
if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
return createConstraint(builder, predRec.getValueAsDef("baseType"));
if (predRec.getName() == "AnyType") {
auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
return op.getOutput();
}
if (predRec.isSubClassOf("TypeDef")) {
std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
StringAttr::get(ctx, typeName));
return op.getOutput();
}
if (predRec.isSubClassOf("AnyTypeOf")) {
std::vector<Value> constraints;
for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
createConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
}
if (predRec.isSubClassOf("AllOfType")) {
std::vector<Value> constraints;
for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
createConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
}
std::string condition = constraint.getPredicate().getCondition();
// Build a CPredOp to match the C constraint built.
irdl::CPredOp op = builder.create<irdl::CPredOp>(
UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
return op;
}
/// Returns the name of the operation without the dialect prefix.
static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
StringRef opName = tblgenOp.getDef().getValueAsString("opName");
return opName;
}
/// Extract an operation to IRDL.
irdl::OperationOp createIRDLOperation(OpBuilder &builder,
tblgen::Operator &tblgenOp) {
MLIRContext *ctx = builder.getContext();
StringRef opName = getOperatorName(tblgenOp);
irdl::OperationOp op = builder.create<irdl::OperationOp>(
UnknownLoc::get(ctx), StringAttr::get(ctx, opName));
// Add the block in the region.
Block &opBlock = op.getBody().emplaceBlock();
OpBuilder consBuilder = OpBuilder::atBlockBegin(&opBlock);
auto getValues = [&](tblgen::Operator::const_value_range namedCons) {
SmallVector<Value> operands;
SmallVector<irdl::VariadicityAttr> variadicity;
for (const NamedTypeConstraint &namedCons : namedCons) {
auto operand = createConstraint(consBuilder, namedCons.constraint);
operands.push_back(operand);
irdl::VariadicityAttr var;
if (namedCons.isOptional())
var = consBuilder.getAttr<irdl::VariadicityAttr>(
irdl::Variadicity::optional);
else if (namedCons.isVariadic())
var = consBuilder.getAttr<irdl::VariadicityAttr>(
irdl::Variadicity::variadic);
else
var = consBuilder.getAttr<irdl::VariadicityAttr>(
irdl::Variadicity::single);
variadicity.push_back(var);
}
return std::make_tuple(operands, variadicity);
};
auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
// Create the operands and results operations.
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
operandVariadicity);
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
resultVariadicity);
return op;
}
static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
MLIRContext *ctx = builder.getContext();
return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
StringAttr::get(ctx, selectedDialect));
}
static std::vector<llvm::Record *>
getOpDefinitions(const RecordKeeper &recordKeeper) {
if (!recordKeeper.getClass("Op"))
return {};
return recordKeeper.getAllDerivedDefinitions("Op");
}
static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
raw_ostream &os) {
// Initialize.
MLIRContext ctx;
ctx.getOrLoadDialect<irdl::IRDLDialect>();
OpBuilder builder(&ctx);
// Create a module op and set it as the insertion point.
OwningOpRef<ModuleOp> module =
builder.create<ModuleOp>(UnknownLoc::get(&ctx));
builder = builder.atBlockBegin(module->getBody());
// Create the dialect and insert it.
irdl::DialectOp dialect = createIRDLDialect(builder);
// Set insertion point to start of DialectOp.
builder = builder.atBlockBegin(&dialect.getBody().emplaceBlock());
std::vector<Record *> defs = getOpDefinitions(recordKeeper);
for (auto *def : defs) {
tblgen::Operator tblgenOp(def);
if (tblgenOp.getDialectName() != selectedDialect)
continue;
createIRDLOperation(builder, tblgenOp);
}
// Print the module.
module->print(os);
return false;
}
static mlir::GenRegistration
genOpDefs("gen-dialect-irdl-defs", "Generate IRDL dialect definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitDialectIRDLDefs(records, os);
});