| //===- OpClass.cpp - Helper classes for Op C++ code emission --------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/TableGen/OpClass.h" |
| |
| #include "mlir/TableGen/Format.h" |
| #include "llvm/ADT/Sequence.h" |
| #include "llvm/ADT/Twine.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include <unordered_set> |
| |
| #define DEBUG_TYPE "mlir-tblgen-opclass" |
| |
| using namespace mlir; |
| using namespace mlir::tblgen; |
| |
| namespace { |
| |
| // Returns space to be emitted after the given C++ `type`. return "" if the |
| // ends with '&' or '*', or is empty, else returns " ". |
| StringRef getSpaceAfterType(StringRef type) { |
| return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " "; |
| } |
| |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // OpMethodParameter definitions |
| //===----------------------------------------------------------------------===// |
| |
| void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const { |
| if (properties & PP_Optional) |
| os << "/*optional*/"; |
| os << type << getSpaceAfterType(type) << name; |
| if (emitDefault && !defaultValue.empty()) |
| os << " = " << defaultValue; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OpMethodParameters definitions |
| //===----------------------------------------------------------------------===// |
| |
| // Factory methods to construct the correct type of `OpMethodParameters` |
| // object based on the arguments. |
| std::unique_ptr<OpMethodParameters> OpMethodParameters::create() { |
| return std::make_unique<OpMethodResolvedParameters>(); |
| } |
| |
| std::unique_ptr<OpMethodParameters> |
| OpMethodParameters::create(StringRef params) { |
| return std::make_unique<OpMethodUnresolvedParameters>(params); |
| } |
| |
| std::unique_ptr<OpMethodParameters> |
| OpMethodParameters::create(llvm::SmallVectorImpl<OpMethodParameter> &¶ms) { |
| return std::make_unique<OpMethodResolvedParameters>(std::move(params)); |
| } |
| |
| std::unique_ptr<OpMethodParameters> |
| OpMethodParameters::create(StringRef type, StringRef name, |
| StringRef defaultValue) { |
| return std::make_unique<OpMethodResolvedParameters>(type, name, defaultValue); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OpMethodUnresolvedParameters definitions |
| //===----------------------------------------------------------------------===// |
| void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const { |
| os << parameters; |
| } |
| |
| void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const { |
| // We need to remove the default values for parameters in method definition. |
| // TODO: We are using '=' and ',' as delimiters for parameter |
| // initializers. This is incorrect for initializer list with more than one |
| // element. Change to a more robust approach. |
| llvm::SmallVector<StringRef, 4> tokens; |
| StringRef params = parameters; |
| while (!params.empty()) { |
| std::pair<StringRef, StringRef> parts = params.split("="); |
| tokens.push_back(parts.first); |
| params = parts.second.split(',').second; |
| } |
| llvm::interleaveComma(tokens, os, [&](StringRef token) { os << token; }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OpMethodResolvedParameters definitions |
| //===----------------------------------------------------------------------===// |
| |
| // Returns true if a method with these parameters makes a method with parameters |
| // `other` redundant. This should return true only if all possible calls to the |
| // other method can be replaced by calls to this method. |
| bool OpMethodResolvedParameters::makesRedundant( |
| const OpMethodResolvedParameters &other) const { |
| const size_t otherNumParams = other.getNumParameters(); |
| const size_t thisNumParams = getNumParameters(); |
| |
| // All calls to the other method can be replaced this method only if this |
| // method has the same or more arguments number of arguments as the other, and |
| // the common arguments have the same type. |
| if (thisNumParams < otherNumParams) |
| return false; |
| for (int idx : llvm::seq<int>(0, otherNumParams)) |
| if (parameters[idx].getType() != other.parameters[idx].getType()) |
| return false; |
| |
| // If all the common arguments have the same type, we can elide the other |
| // method if this method has the same number of arguments as other or the |
| // first argument after the common ones has a default value (and by C++ |
| // requirement, all the later ones will also have a default value). |
| return thisNumParams == otherNumParams || |
| parameters[otherNumParams].hasDefaultValue(); |
| } |
| |
| void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const { |
| llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) { |
| param.writeDeclTo(os); |
| }); |
| } |
| |
| void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const { |
| llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) { |
| param.writeDefTo(os); |
| }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OpMethodSignature definitions |
| //===----------------------------------------------------------------------===// |
| |
| // Returns if a method with this signature makes a method with `other` signature |
| // redundant. Only supports resolved parameters. |
| bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const { |
| if (methodName != other.methodName) |
| return false; |
| auto *resolvedThis = dyn_cast<OpMethodResolvedParameters>(parameters.get()); |
| auto *resolvedOther = |
| dyn_cast<OpMethodResolvedParameters>(other.parameters.get()); |
| if (resolvedThis && resolvedOther) |
| return resolvedThis->makesRedundant(*resolvedOther); |
| return false; |
| } |
| |
| void OpMethodSignature::writeDeclTo(raw_ostream &os) const { |
| os << returnType << getSpaceAfterType(returnType) << methodName << "("; |
| parameters->writeDeclTo(os); |
| os << ")"; |
| } |
| |
| void OpMethodSignature::writeDefTo(raw_ostream &os, |
| StringRef namePrefix) const { |
| os << returnType << getSpaceAfterType(returnType) << namePrefix |
| << (namePrefix.empty() ? "" : "::") << methodName << "("; |
| parameters->writeDefTo(os); |
| os << ")"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OpMethodBody definitions |
| //===----------------------------------------------------------------------===// |
| |
| OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {} |
| |
| OpMethodBody &OpMethodBody::operator<<(Twine content) { |
| if (isEffective) |
| body.append(content.str()); |
| return *this; |
| } |
| |
| OpMethodBody &OpMethodBody::operator<<(int content) { |
| if (isEffective) |
| body.append(std::to_string(content)); |
| return *this; |
| } |
| |
| OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) { |
| if (isEffective) |
| body.append(content.str()); |
| return *this; |
| } |
| |
| void OpMethodBody::writeTo(raw_ostream &os) const { |
| auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; }); |
| os << bodyRef; |
| if (bodyRef.empty() || bodyRef.back() != '\n') |
| os << "\n"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OpMethod definitions |
| //===----------------------------------------------------------------------===// |
| |
| void OpMethod::writeDeclTo(raw_ostream &os) const { |
| os.indent(2); |
| if (isStatic()) |
| os << "static "; |
| methodSignature.writeDeclTo(os); |
| os << ";"; |
| } |
| |
| void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const { |
| // Do not write definition if the method is decl only. |
| if (properties & MP_Declaration) |
| return; |
| methodSignature.writeDefTo(os, namePrefix); |
| os << " {\n"; |
| methodBody.writeTo(os); |
| os << "}"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OpConstructor definitions |
| //===----------------------------------------------------------------------===// |
| |
| void OpConstructor::addMemberInitializer(StringRef name, StringRef value) { |
| memberInitializers.append(std::string(llvm::formatv( |
| "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value))); |
| } |
| |
| void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const { |
| // Do not write definition if the method is decl only. |
| if (properties & MP_Declaration) |
| return; |
| |
| methodSignature.writeDefTo(os, namePrefix); |
| os << " " << memberInitializers << " {\n"; |
| methodBody.writeTo(os); |
| os << "}"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Class definitions |
| //===----------------------------------------------------------------------===// |
| |
| Class::Class(StringRef name) : className(name) {} |
| |
| void Class::newField(StringRef type, StringRef name, StringRef defaultValue) { |
| std::string varName = formatv("{0} {1}", type, name).str(); |
| std::string field = defaultValue.empty() |
| ? varName |
| : formatv("{0} = {1}", varName, defaultValue).str(); |
| fields.push_back(std::move(field)); |
| } |
| void Class::writeDeclTo(raw_ostream &os) const { |
| bool hasPrivateMethod = false; |
| os << "class " << className << " {\n"; |
| os << "public:\n"; |
| |
| forAllMethods([&](const OpMethod &method) { |
| if (!method.isPrivate()) { |
| method.writeDeclTo(os); |
| os << '\n'; |
| } else { |
| hasPrivateMethod = true; |
| } |
| }); |
| |
| os << '\n'; |
| os << "private:\n"; |
| if (hasPrivateMethod) { |
| forAllMethods([&](const OpMethod &method) { |
| if (method.isPrivate()) { |
| method.writeDeclTo(os); |
| os << '\n'; |
| } |
| }); |
| os << '\n'; |
| } |
| |
| for (const auto &field : fields) |
| os.indent(2) << field << ";\n"; |
| os << "};\n"; |
| } |
| |
| void Class::writeDefTo(raw_ostream &os) const { |
| forAllMethods([&](const OpMethod &method) { |
| method.writeDefTo(os, className); |
| os << "\n\n"; |
| }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // OpClass definitions |
| //===----------------------------------------------------------------------===// |
| |
| OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) |
| : Class(name), extraClassDeclaration(extraClassDeclaration) {} |
| |
| void OpClass::addTrait(Twine trait) { |
| auto traitStr = trait.str(); |
| if (traitsSet.insert(traitStr).second) |
| traitsVec.push_back(std::move(traitStr)); |
| } |
| |
| void OpClass::writeDeclTo(raw_ostream &os) const { |
| os << "class " << className << " : public ::mlir::Op<" << className; |
| for (const auto &trait : traitsVec) |
| os << ", " << trait; |
| os << "> {\npublic:\n"; |
| os << " using Op::Op;\n"; |
| os << " using Adaptor = " << className << "Adaptor;\n"; |
| |
| bool hasPrivateMethod = false; |
| forAllMethods([&](const OpMethod &method) { |
| if (!method.isPrivate()) { |
| method.writeDeclTo(os); |
| os << "\n"; |
| } else { |
| hasPrivateMethod = true; |
| } |
| }); |
| |
| // TODO: Add line control markers to make errors easier to debug. |
| if (!extraClassDeclaration.empty()) |
| os << extraClassDeclaration << "\n"; |
| |
| if (hasPrivateMethod) { |
| os << "\nprivate:\n"; |
| forAllMethods([&](const OpMethod &method) { |
| if (method.isPrivate()) { |
| method.writeDeclTo(os); |
| os << "\n"; |
| } |
| }); |
| } |
| |
| os << "};\n"; |
| } |