| //===----------------------------------------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "UseNewMLIROpBuilderCheck.h" |
| #include "clang/ASTMatchers/ASTMatchers.h" |
| #include "clang/Basic/LLVM.h" |
| #include "clang/Lex/Lexer.h" |
| #include "clang/Tooling/Transformer/RangeSelector.h" |
| #include "clang/Tooling/Transformer/RewriteRule.h" |
| #include "clang/Tooling/Transformer/SourceCode.h" |
| #include "clang/Tooling/Transformer/Stencil.h" |
| #include "llvm/Support/Error.h" |
| #include "llvm/Support/FormatVariadic.h" |
| |
| namespace clang::tidy::llvm_check { |
| |
| using namespace ::clang::ast_matchers; |
| using namespace ::clang::transformer; |
| |
| static EditGenerator rewrite(RangeSelector Call, RangeSelector Builder) { |
| // This is using an EditGenerator rather than ASTEdit as we want to warn even |
| // if in macro. |
| return [Call = std::move(Call), |
| Builder = std::move(Builder)](const MatchFinder::MatchResult &Result) |
| -> Expected<SmallVector<transformer::Edit, 1>> { |
| Expected<CharSourceRange> CallRange = Call(Result); |
| if (!CallRange) |
| return CallRange.takeError(); |
| SourceManager &SM = *Result.SourceManager; |
| const LangOptions &LangOpts = Result.Context->getLangOpts(); |
| SourceLocation Begin = CallRange->getBegin(); |
| |
| // This will result in just a warning and no edit. |
| bool InMacro = CallRange->getBegin().isMacroID(); |
| if (InMacro) { |
| while (SM.isMacroArgExpansion(Begin)) |
| Begin = SM.getImmediateExpansionRange(Begin).getBegin(); |
| Edit WarnOnly; |
| WarnOnly.Kind = EditKind::Range; |
| WarnOnly.Range = CharSourceRange::getCharRange(Begin, Begin); |
| return SmallVector<Edit, 1>({WarnOnly}); |
| } |
| |
| // This will try to extract the template argument as written so that the |
| // rewritten code looks closest to original. |
| auto NextToken = [&](std::optional<Token> CurrentToken) { |
| if (!CurrentToken) |
| return CurrentToken; |
| if (CurrentToken->is(clang::tok::eof)) |
| return std::optional<Token>(); |
| return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM, |
| LangOpts); |
| }; |
| std::optional<Token> LessToken = |
| clang::Lexer::findNextToken(Begin, SM, LangOpts); |
| while (LessToken && LessToken->getKind() != clang::tok::less) { |
| LessToken = NextToken(LessToken); |
| } |
| if (!LessToken) { |
| return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument, |
| "missing '<' token"); |
| } |
| |
| std::optional<Token> EndToken = NextToken(LessToken); |
| std::optional<Token> GreaterToken = NextToken(EndToken); |
| for (; GreaterToken && GreaterToken->getKind() != clang::tok::greater; |
| GreaterToken = NextToken(GreaterToken)) { |
| EndToken = GreaterToken; |
| } |
| if (!EndToken) { |
| return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument, |
| "missing '>' token"); |
| } |
| |
| std::optional<Token> ArgStart = NextToken(GreaterToken); |
| if (!ArgStart || ArgStart->getKind() != clang::tok::l_paren) { |
| return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument, |
| "missing '(' token"); |
| } |
| std::optional<Token> Arg = NextToken(ArgStart); |
| if (!Arg) { |
| return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument, |
| "unexpected end of file"); |
| } |
| const bool HasArgs = Arg->getKind() != clang::tok::r_paren; |
| |
| Expected<CharSourceRange> BuilderRange = Builder(Result); |
| if (!BuilderRange) |
| return BuilderRange.takeError(); |
| |
| // Helper for concatting below. |
| auto GetText = [&](const CharSourceRange &Range) { |
| return clang::Lexer::getSourceText(Range, SM, LangOpts); |
| }; |
| |
| Edit Replace; |
| Replace.Kind = EditKind::Range; |
| Replace.Range.setBegin(CallRange->getBegin()); |
| Replace.Range.setEnd(ArgStart->getEndLoc()); |
| const Expr *BuilderExpr = Result.Nodes.getNodeAs<Expr>("builder"); |
| std::string BuilderText = GetText(*BuilderRange).str(); |
| if (BuilderExpr->getType()->isPointerType()) { |
| BuilderText = BuilderExpr->isImplicitCXXThis() |
| ? "*this" |
| : llvm::formatv("*{}", BuilderText).str(); |
| } |
| const StringRef OpType = GetText(CharSourceRange::getTokenRange( |
| LessToken->getEndLoc(), EndToken->getLastLoc())); |
| Replace.Replacement = llvm::formatv("{}::create({}{}", OpType, BuilderText, |
| HasArgs ? ", " : ""); |
| |
| return SmallVector<Edit, 1>({Replace}); |
| }; |
| } |
| |
| static RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() { |
| Stencil Message = cat("use 'OpType::create(builder, ...)' instead of " |
| "'builder.create<OpType>(...)'"); |
| // Match a create call on an OpBuilder. |
| auto BuilderType = cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder")); |
| ast_matchers::internal::Matcher<Stmt> Base = |
| cxxMemberCallExpr( |
| on(expr(anyOf(hasType(BuilderType), hasType(pointsTo(BuilderType)))) |
| .bind("builder")), |
| callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()), |
| hasName("create")))) |
| .bind("call"); |
| return applyFirst( |
| // Attempt rewrite given an lvalue builder, else just warn. |
| {makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), Base), |
| rewrite(node("call"), node("builder")), Message), |
| makeRule(Base, noopEdit(node("call")), Message)}); |
| } |
| |
| UseNewMlirOpBuilderCheck::UseNewMlirOpBuilderCheck(StringRef Name, |
| ClangTidyContext *Context) |
| : TransformerClangTidyCheck(useNewMlirOpBuilderCheckRule(), Name, Context) { |
| } |
| |
| } // namespace clang::tidy::llvm_check |