blob: 0014153cceaa3cfbad3cdb584a786870bfb84999 [file] [log] [blame]
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "UseNewMLIROpBuilderCheck.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/Basic/LLVM.h"
#include "clang/Lex/Lexer.h"
#include "clang/Tooling/Transformer/RangeSelector.h"
#include "clang/Tooling/Transformer/RewriteRule.h"
#include "clang/Tooling/Transformer/SourceCode.h"
#include "clang/Tooling/Transformer/Stencil.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"
namespace clang::tidy::llvm_check {
using namespace ::clang::ast_matchers;
using namespace ::clang::transformer;
static EditGenerator rewrite(RangeSelector Call, RangeSelector Builder) {
// This is using an EditGenerator rather than ASTEdit as we want to warn even
// if in macro.
return [Call = std::move(Call),
Builder = std::move(Builder)](const MatchFinder::MatchResult &Result)
-> Expected<SmallVector<transformer::Edit, 1>> {
Expected<CharSourceRange> CallRange = Call(Result);
if (!CallRange)
return CallRange.takeError();
SourceManager &SM = *Result.SourceManager;
const LangOptions &LangOpts = Result.Context->getLangOpts();
SourceLocation Begin = CallRange->getBegin();
// This will result in just a warning and no edit.
bool InMacro = CallRange->getBegin().isMacroID();
if (InMacro) {
while (SM.isMacroArgExpansion(Begin))
Begin = SM.getImmediateExpansionRange(Begin).getBegin();
Edit WarnOnly;
WarnOnly.Kind = EditKind::Range;
WarnOnly.Range = CharSourceRange::getCharRange(Begin, Begin);
return SmallVector<Edit, 1>({WarnOnly});
}
// This will try to extract the template argument as written so that the
// rewritten code looks closest to original.
auto NextToken = [&](std::optional<Token> CurrentToken) {
if (!CurrentToken)
return CurrentToken;
if (CurrentToken->is(clang::tok::eof))
return std::optional<Token>();
return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM,
LangOpts);
};
std::optional<Token> LessToken =
clang::Lexer::findNextToken(Begin, SM, LangOpts);
while (LessToken && LessToken->getKind() != clang::tok::less) {
LessToken = NextToken(LessToken);
}
if (!LessToken) {
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
"missing '<' token");
}
std::optional<Token> EndToken = NextToken(LessToken);
std::optional<Token> GreaterToken = NextToken(EndToken);
for (; GreaterToken && GreaterToken->getKind() != clang::tok::greater;
GreaterToken = NextToken(GreaterToken)) {
EndToken = GreaterToken;
}
if (!EndToken) {
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
"missing '>' token");
}
std::optional<Token> ArgStart = NextToken(GreaterToken);
if (!ArgStart || ArgStart->getKind() != clang::tok::l_paren) {
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
"missing '(' token");
}
std::optional<Token> Arg = NextToken(ArgStart);
if (!Arg) {
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
"unexpected end of file");
}
const bool HasArgs = Arg->getKind() != clang::tok::r_paren;
Expected<CharSourceRange> BuilderRange = Builder(Result);
if (!BuilderRange)
return BuilderRange.takeError();
// Helper for concatting below.
auto GetText = [&](const CharSourceRange &Range) {
return clang::Lexer::getSourceText(Range, SM, LangOpts);
};
Edit Replace;
Replace.Kind = EditKind::Range;
Replace.Range.setBegin(CallRange->getBegin());
Replace.Range.setEnd(ArgStart->getEndLoc());
const Expr *BuilderExpr = Result.Nodes.getNodeAs<Expr>("builder");
std::string BuilderText = GetText(*BuilderRange).str();
if (BuilderExpr->getType()->isPointerType()) {
BuilderText = BuilderExpr->isImplicitCXXThis()
? "*this"
: llvm::formatv("*{}", BuilderText).str();
}
const StringRef OpType = GetText(CharSourceRange::getTokenRange(
LessToken->getEndLoc(), EndToken->getLastLoc()));
Replace.Replacement = llvm::formatv("{}::create({}{}", OpType, BuilderText,
HasArgs ? ", " : "");
return SmallVector<Edit, 1>({Replace});
};
}
static RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {
Stencil Message = cat("use 'OpType::create(builder, ...)' instead of "
"'builder.create<OpType>(...)'");
// Match a create call on an OpBuilder.
auto BuilderType = cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"));
ast_matchers::internal::Matcher<Stmt> Base =
cxxMemberCallExpr(
on(expr(anyOf(hasType(BuilderType), hasType(pointsTo(BuilderType))))
.bind("builder")),
callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()),
hasName("create"))))
.bind("call");
return applyFirst(
// Attempt rewrite given an lvalue builder, else just warn.
{makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), Base),
rewrite(node("call"), node("builder")), Message),
makeRule(Base, noopEdit(node("call")), Message)});
}
/// Constructs the check by handing the rule produced by
/// useNewMlirOpBuilderCheckRule() to the TransformerClangTidyCheck base,
/// which performs all matching and fix-it emission; no other state is kept.
UseNewMlirOpBuilderCheck::UseNewMlirOpBuilderCheck(
    StringRef Name, ClangTidyContext *Context)
    : TransformerClangTidyCheck(useNewMlirOpBuilderCheckRule(), Name,
                                Context) {}
} // namespace clang::tidy::llvm_check