blob: 6d7d1d6b87c60acb66b3de1001c3901d6a5a78a3 [file] [log] [blame]
//===--- UseConstraintsCheck.cpp - clang-tidy -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "UseConstraintsCheck.h"
#include "clang/AST/ASTContext.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Lex/Lexer.h"
#include "../utils/LexerUtils.h"
#include <optional>
#include <utility>
using namespace clang::ast_matchers;
namespace clang::tidy::modernize {
// Result of matching an enable_if / enable_if_t specialization in a type.
struct EnableIfData {
  // The matched template specialization. For the
  // 'typename enable_if<...>::type' form this is the 'enable_if<...>'
  // qualifier; for the alias form it is the 'enable_if_t<...>' itself.
  TemplateSpecializationTypeLoc Loc;
  // The type location in which the match was found, after stripping at most
  // one pointer/reference layer and top-level qualifiers. For the return-type
  // case this is the range replaced by the fix-it.
  TypeLoc Outer;
};
namespace {
// Matches a FunctionDecl whose redeclaration chain has more than one entry,
// i.e. some declaration of the same function other than Node itself exists.
AST_MATCHER(FunctionDecl, hasOtherDeclarations) {
  const auto First = Node.redecls_begin();
  const auto Last = Node.redecls_end();
  // An empty chain cannot contain "other" declarations.
  if (First == Last)
    return false;
  // Step past the first entry; anything left is another declaration.
  auto Second = First;
  ++Second;
  return Second != Last;
}
} // namespace
// Registers a matcher for function templates whose templated function is a
// definition with no other declarations. It binds the template as
// "functionTemplate", the function as "function" and the function's return
// TypeLoc as "return".
void UseConstraintsCheck::registerMatchers(MatchFinder *Finder) {
  auto SoleDefinition =
      functionDecl(unless(hasOtherDeclarations()), isDefinition(),
                   hasReturnTypeLoc(typeLoc().bind("return")))
          .bind("function");
  auto Template =
      functionTemplateDecl(has(SoleDefinition)).bind("functionTemplate");
  Finder->addMatcher(Template, this);
}
// Matches the 'typename enable_if<Condition[, Type]>::type' form.
//
// The type must be a dependent member type named 'type', introduced with the
// 'typename' keyword. Its qualifier must be a specialization of a template
// named 'enable_if' (in any namespace) with one or two template arguments.
// On success, returns that 'enable_if<...>' specialization (the qualifier),
// not the whole dependent type.
//
// A bare 'enable_if<...>' without '::type' is deliberately not matched. It
// names the enable_if class itself rather than the constrained type, so
// rewriting it into a requires clause would change the program's meaning.
// This mirrors matchEnableIfSpecializationImplTrait, which also requires the
// alias to expand to 'typename ...::type'.
static std::optional<TemplateSpecializationTypeLoc>
matchEnableIfSpecializationImplTypename(TypeLoc TheType) {
  const auto Dep = TheType.getAs<DependentNameTypeLoc>();
  if (!Dep)
    return std::nullopt;

  const IdentifierInfo *Identifier = Dep.getTypePtr()->getIdentifier();
  if (!Identifier || Identifier->getName() != "type" ||
      Dep.getTypePtr()->getKeyword() != ElaboratedTypeKeyword::Typename)
    return std::nullopt;

  // The qualifier is not necessarily a type specifier. For example, it may be
  // a dependent identifier such as 'X' in 'typename T::X::type'. In that case
  // getTypeLoc() yields a null TypeLoc, and querying it would dereference a
  // null Type, so bail out.
  const TypeLoc Qualifier = Dep.getQualifierLoc().getTypeLoc();
  if (Qualifier.isNull())
    return std::nullopt;

  const auto SpecializationLoc =
      Qualifier.getAs<TemplateSpecializationTypeLoc>();
  if (!SpecializationLoc)
    return std::nullopt;

  const auto *Specialization =
      dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
  if (!Specialization)
    return std::nullopt;

  const TemplateDecl *TD =
      Specialization->getTemplateName().getAsTemplateDecl();
  if (!TD || TD->getName() != "enable_if")
    return std::nullopt;

  // enable_if takes a condition and an optional type.
  const unsigned NumArgs = SpecializationLoc.getNumArgs();
  if (NumArgs != 1 && NumArgs != 2)
    return std::nullopt;

  return SpecializationLoc;
}
// Matches the 'enable_if_t<Condition[, Type]>' alias form, optionally written
// with a qualifier such as 'std::'.
//
// The type must be a specialization of an alias template named 'enable_if_t'
// (in any namespace) with one or two template arguments. The alias must
// expand to a dependent 'typename ...::type'. On success, returns the
// 'enable_if_t<...>' specialization.
static std::optional<TemplateSpecializationTypeLoc>
matchEnableIfSpecializationImplTrait(TypeLoc TheType) {
  // Look through a qualified spelling such as 'std::enable_if_t<...>'.
  if (const auto Elaborated = TheType.getAs<ElaboratedTypeLoc>())
    TheType = Elaborated.getNamedTypeLoc();

  const auto SpecializationLoc =
      TheType.getAs<TemplateSpecializationTypeLoc>();
  if (!SpecializationLoc)
    return std::nullopt;

  const auto *Specialization =
      dyn_cast<TemplateSpecializationType>(SpecializationLoc.getTypePtr());
  if (!Specialization)
    return std::nullopt;

  const TemplateDecl *TD =
      Specialization->getTemplateName().getAsTemplateDecl();
  if (!TD || TD->getName() != "enable_if_t")
    return std::nullopt;
  if (!Specialization->isTypeAlias())
    return std::nullopt;

  // The alias must expand to 'typename <something>::type'.
  const auto *AliasedType =
      dyn_cast<DependentNameType>(Specialization->getAliasedType());
  if (!AliasedType)
    return std::nullopt;

  // Null-check the identifier for consistency with the
  // 'typename enable_if<...>::type' matcher, instead of dereferencing it
  // unconditionally.
  const IdentifierInfo *Identifier = AliasedType->getIdentifier();
  if (!Identifier || Identifier->getName() != "type" ||
      AliasedType->getKeyword() != ElaboratedTypeKeyword::Typename)
    return std::nullopt;

  // enable_if_t takes a condition and an optional type.
  const unsigned NumArgs = SpecializationLoc.getNumArgs();
  if (NumArgs != 1 && NumArgs != 2)
    return std::nullopt;

  return SpecializationLoc;
}
// Matches TheType against either supported enable_if spelling. The
// 'typename enable_if<...>::type' form is tried first, then the
// 'enable_if_t<...>' alias form. Returns the matched specialization, or
// std::nullopt if neither form matches.
static std::optional<TemplateSpecializationTypeLoc>
matchEnableIfSpecializationImpl(TypeLoc TheType) {
  std::optional<TemplateSpecializationTypeLoc> Match =
      matchEnableIfSpecializationImplTypename(TheType);
  if (!Match)
    Match = matchEnableIfSpecializationImplTrait(TheType);
  return Match;
}
// Strips at most one pointer or reference layer from TheType, then any
// top-level qualifiers, and checks whether what remains is an enable_if /
// enable_if_t specialization.
//
// On success, returns the specialization together with the stripped TypeLoc
// it was found in; otherwise returns std::nullopt.
static std::optional<EnableIfData>
matchEnableIfSpecialization(TypeLoc TheType) {
  TypeLoc Inner = TheType;

  if (const auto AsPointer = Inner.getAs<PointerTypeLoc>())
    Inner = AsPointer.getPointeeLoc();
  else if (const auto AsReference = Inner.getAs<ReferenceTypeLoc>())
    Inner = AsReference.getPointeeLoc();

  if (const auto AsQualified = Inner.getAs<QualifiedTypeLoc>())
    Inner = AsQualified.getUnqualifiedLoc();

  const std::optional<TemplateSpecializationTypeLoc> Spec =
      matchEnableIfSpecializationImpl(Inner);
  if (!Spec)
    return std::nullopt;
  return EnableIfData{*Spec, Inner};
}
static std::pair<std::optional<EnableIfData>, const Decl *>
matchTrailingTemplateParam(const FunctionTemplateDecl *FunctionTemplate) {
// For non-type trailing param, match very specifically
// 'template <..., enable_if_type<Condition, Type> = Default>' where
// enable_if_type is 'enable_if' or 'enable_if_t'. E.g., 'template <typename
// T, enable_if_t<is_same_v<T, bool>, int*> = nullptr>
//
// Otherwise, match a trailing default type arg.
// E.g., 'template <typename T, typename = enable_if_t<is_same_v<T, bool>>>'
const TemplateParameterList *TemplateParams =
FunctionTemplate->getTemplateParameters();
if (TemplateParams->size() == 0)
return {};
const NamedDecl *LastParam =
TemplateParams->getParam(TemplateParams->size() - 1);
if (const auto *LastTemplateParam =
dyn_cast<NonTypeTemplateParmDecl>(LastParam)) {
if (!LastTemplateParam->hasDefaultArgument() ||
!LastTemplateParam->getName().empty())
return {};
return {matchEnableIfSpecialization(
LastTemplateParam->getTypeSourceInfo()->getTypeLoc()),
LastTemplateParam};
}
if (const auto *LastTemplateParam =
dyn_cast<TemplateTypeParmDecl>(LastParam)) {
if (LastTemplateParam->hasDefaultArgument() &&
LastTemplateParam->getIdentifier() == nullptr) {
return {matchEnableIfSpecialization(
LastTemplateParam->getDefaultArgumentInfo()->getTypeLoc()),
LastTemplateParam};
}
}
return {};
}
// Returns the file location of Element's closing '>'.
//
// The '>' may be half of a '>>' token that the lexer split. Its location then
// points into a scratch buffer allocated by the source manager, and
// getFileLoc maps it back to a real file position.
template <typename T>
static SourceLocation getRAngleFileLoc(const SourceManager &SM,
                                       const T &Element) {
  const SourceLocation RAngle = Element.getRAngleLoc();
  return SM.getFileLoc(RAngle);
}
// Returns the source range of the condition (first template argument) of an
// enable_if specialization. It starts just past the '<' and ends at the comma
// before the second argument, or at the closing '>' when there is only one
// argument.
//
// The end comes from token search rather than from the argument's own
// SourceRange. That range's End is the start of the last token: in
// `enable_if<AAA && BBB>` it points at the first 'B' of 'BBB'.
static SourceRange
getConditionRange(ASTContext &Context,
                  const TemplateSpecializationTypeLoc &EnableIf) {
  const SourceManager &SM = Context.getSourceManager();
  const SourceLocation Start = EnableIf.getLAngleLoc().getLocWithOffset(1);

  if (EnableIf.getNumArgs() <= 1)
    return {Start, getRAngleFileLoc(SM, EnableIf)};

  const TemplateArgumentLoc SecondArg = EnableIf.getArgLoc(1);
  const SourceLocation SeparatingComma = utils::lexer::findPreviousTokenKind(
      SecondArg.getSourceRange().getBegin(), SM, Context.getLangOpts(),
      tok::comma);
  return {Start, SeparatingComma};
}
// Returns the source range of the second template argument of an enable_if
// specialization: from just after the separating comma up to the closing '>'.
// Argument 1 is read unconditionally, so the caller must ensure the
// specialization has at least two arguments.
static SourceRange getTypeRange(ASTContext &Context,
                                const TemplateSpecializationTypeLoc &EnableIf) {
  const SourceManager &SM = Context.getSourceManager();
  const SourceLocation SecondArgBegin =
      EnableIf.getArgLoc(1).getSourceRange().getBegin();
  const SourceLocation Comma = utils::lexer::findPreviousTokenKind(
      SecondArgBegin, SM, Context.getLangOpts(), tok::comma);
  return {Comma.getLocWithOffset(1), getRAngleFileLoc(SM, EnableIf)};
}
// Returns the source text of the second template argument of an enable_if
// specialization, with surrounding whitespace trimmed. For example, it
// returns 'TheType' for 'enable_if_t<Condition, TheType>'.
//
// When only the condition is given, returns "void", matching std::enable_if's
// default type. Returns std::nullopt if the source text cannot be retrieved.
static std::optional<StringRef>
getTypeText(ASTContext &Context,
            const TemplateSpecializationTypeLoc &EnableIf) {
  if (EnableIf.getNumArgs() <= 1)
    return "void";

  bool Invalid = false;
  const CharSourceRange Range =
      CharSourceRange::getCharRange(getTypeRange(Context, EnableIf));
  const StringRef Text = Lexer::getSourceText(
      Range, Context.getSourceManager(), Context.getLangOpts(), &Invalid);
  if (Invalid)
    return std::nullopt;
  return Text.trim();
}
static std::optional<SourceLocation>
findInsertionForConstraint(const FunctionDecl *Function, ASTContext &Context) {
SourceManager &SM = Context.getSourceManager();
const LangOptions &LangOpts = Context.getLangOpts();
if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(Function)) {
for (const CXXCtorInitializer *Init : Constructor->inits()) {
if (Init->getSourceOrder() == 0)
return utils::lexer::findPreviousTokenKind(Init->getSourceLocation(),
SM, LangOpts, tok::colon);
}
if (Constructor->init_begin() != Constructor->init_end())
return std::nullopt;
}
if (Function->isDeleted()) {
SourceLocation FunctionEnd = Function->getSourceRange().getEnd();
return utils::lexer::findNextAnyTokenKind(FunctionEnd, SM, LangOpts,
tok::equal, tok::equal);
}
const Stmt *Body = Function->getBody();
if (!Body)
return std::nullopt;
return Body->getBeginLoc();
}
// Conservatively reports whether Expression is a primary expression, after
// looking through one implicit cast.
//
// Only parenthesized expressions and dependent-scope name references (such as
// 'T::value') are recognized. A 'true' result is therefore reliable, but
// 'false' may be a false negative.
//
// Declared static so that, like the other file-local helpers here, it has
// internal linkage instead of leaking a global symbol.
static bool isPrimaryExpression(const Expr *Expression) {
  if (const auto *Cast = dyn_cast<ImplicitCastExpr>(Expression))
    Expression = Cast->getSubExprAsWritten();
  return isa<ParenExpr, DependentScopeDeclRefExpr>(Expression);
}
// Return the original source text of an enable_if_t condition, i.e., the
// first template argument). For example, in
// 'enable_if_t<FirstCondition || SecondCondition, AType>', the text
// the text 'FirstCondition || SecondCondition' is returned.
static std::optional<std::string> getConditionText(const Expr *ConditionExpr,
SourceRange ConditionRange,
ASTContext &Context) {
SourceManager &SM = Context.getSourceManager();
const LangOptions &LangOpts = Context.getLangOpts();
SourceLocation PrevTokenLoc = ConditionRange.getEnd();
if (PrevTokenLoc.isInvalid())
return std::nullopt;
const bool SkipComments = false;
Token PrevToken;
std::tie(PrevToken, PrevTokenLoc) = utils::lexer::getPreviousTokenAndStart(
PrevTokenLoc, SM, LangOpts, SkipComments);
bool EndsWithDoubleSlash =
PrevToken.is(tok::comment) &&
Lexer::getSourceText(CharSourceRange::getCharRange(
PrevTokenLoc, PrevTokenLoc.getLocWithOffset(2)),
SM, LangOpts) == "//";
bool Invalid = false;
llvm::StringRef ConditionText = Lexer::getSourceText(
CharSourceRange::getCharRange(ConditionRange), SM, LangOpts, &Invalid);
if (Invalid)
return std::nullopt;
auto AddParens = [&](llvm::StringRef Text) -> std::string {
if (isPrimaryExpression(ConditionExpr))
return Text.str();
return "(" + Text.str() + ")";
};
if (EndsWithDoubleSlash)
return AddParens(ConditionText);
return AddParens(ConditionText.trim());
}
// Handle functions that return enable_if_t, e.g.,
// template <...>
// enable_if_t<Condition, ReturnType> function();
//
// Return a vector of FixItHints if the code can be replaced with
// a C++20 requires clause. In the example above, returns FixItHints
// to result in
// template <...>
// ReturnType function() requires Condition {}
static std::vector<FixItHint> handleReturnType(const FunctionDecl *Function,
const TypeLoc &ReturnType,
const EnableIfData &EnableIf,
ASTContext &Context) {
TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
std::optional<std::string> ConditionText = getConditionText(
EnableCondition.getSourceExpression(), ConditionRange, Context);
if (!ConditionText)
return {};
std::optional<StringRef> TypeText = getTypeText(Context, EnableIf.Loc);
if (!TypeText)
return {};
SmallVector<const Expr *, 3> ExistingConstraints;
Function->getAssociatedConstraints(ExistingConstraints);
if (!ExistingConstraints.empty()) {
// FIXME - Support adding new constraints to existing ones. Do we need to
// consider subsumption?
return {};
}
std::optional<SourceLocation> ConstraintInsertionLoc =
findInsertionForConstraint(Function, Context);
if (!ConstraintInsertionLoc)
return {};
std::vector<FixItHint> FixIts;
FixIts.push_back(FixItHint::CreateReplacement(
CharSourceRange::getTokenRange(EnableIf.Outer.getSourceRange()),
*TypeText));
FixIts.push_back(FixItHint::CreateInsertion(
*ConstraintInsertionLoc, "requires " + *ConditionText + " "));
return FixIts;
}
// Handle enable_if_t in a trailing template parameter, e.g.,
// template <..., enable_if_t<Condition, Type> = Type{}>
// ReturnType function();
//
// Return a vector of FixItHints if the code can be replaced with
// a C++20 requires clause. In the example above, returns FixItHints
// to result in
// template <...>
// ReturnType function() requires Condition {}
static std::vector<FixItHint>
handleTrailingTemplateType(const FunctionTemplateDecl *FunctionTemplate,
const FunctionDecl *Function,
const Decl *LastTemplateParam,
const EnableIfData &EnableIf, ASTContext &Context) {
SourceManager &SM = Context.getSourceManager();
const LangOptions &LangOpts = Context.getLangOpts();
TemplateArgumentLoc EnableCondition = EnableIf.Loc.getArgLoc(0);
SourceRange ConditionRange = getConditionRange(Context, EnableIf.Loc);
std::optional<std::string> ConditionText = getConditionText(
EnableCondition.getSourceExpression(), ConditionRange, Context);
if (!ConditionText)
return {};
SmallVector<const Expr *, 3> ExistingConstraints;
Function->getAssociatedConstraints(ExistingConstraints);
if (!ExistingConstraints.empty()) {
// FIXME - Support adding new constraints to existing ones. Do we need to
// consider subsumption?
return {};
}
SourceRange RemovalRange;
const TemplateParameterList *TemplateParams =
FunctionTemplate->getTemplateParameters();
if (!TemplateParams || TemplateParams->size() == 0)
return {};
if (TemplateParams->size() == 1) {
RemovalRange =
SourceRange(TemplateParams->getTemplateLoc(),
getRAngleFileLoc(SM, *TemplateParams).getLocWithOffset(1));
} else {
RemovalRange =
SourceRange(utils::lexer::findPreviousTokenKind(
LastTemplateParam->getSourceRange().getBegin(), SM,
LangOpts, tok::comma),
getRAngleFileLoc(SM, *TemplateParams));
}
std::optional<SourceLocation> ConstraintInsertionLoc =
findInsertionForConstraint(Function, Context);
if (!ConstraintInsertionLoc)
return {};
std::vector<FixItHint> FixIts;
FixIts.push_back(
FixItHint::CreateRemoval(CharSourceRange::getCharRange(RemovalRange)));
FixIts.push_back(FixItHint::CreateInsertion(
*ConstraintInsertionLoc, "requires " + *ConditionText + " "));
return FixIts;
}
// Reports enable_if-based SFINAE on a matched function template and attaches
// fix-its that rewrite it as a C++20 requires clause. Two shapes are handled:
//
// Case 1. Return type of function
//
//   template <...>
//   enable_if_t<Condition, ReturnType> function() {}
//
// (or 'typename enable_if<Condition, ReturnType>::type function() {}')
//
// Case 2. Trailing template parameter
//
//   template <..., enable_if_t<Condition, Type> = Type{}>
//   ReturnType function() {}
//
// or
//
//   template <..., typename = enable_if_t<Condition, void>>
//   ReturnType function() {}
//
// Case 1 is checked first; Case 2 is only considered if it does not apply.
void UseConstraintsCheck::check(const MatchFinder::MatchResult &Result) {
  const auto &Bound = Result.Nodes;
  const auto *FunctionTemplate =
      Bound.getNodeAs<FunctionTemplateDecl>("functionTemplate");
  const auto *Function = Bound.getNodeAs<FunctionDecl>("function");
  const auto *ReturnType = Bound.getNodeAs<TypeLoc>("return");
  if (!FunctionTemplate || !Function || !ReturnType)
    return;

  ASTContext &ASTCtx = *Result.Context;

  // Case 1. Return type of function
  if (const std::optional<EnableIfData> ReturnEnableIf =
          matchEnableIfSpecialization(*ReturnType)) {
    diag(ReturnType->getBeginLoc(),
         "use C++20 requires constraints instead of enable_if")
        << handleReturnType(Function, *ReturnType, *ReturnEnableIf, ASTCtx);
    return;
  }

  // Case 2. Trailing template parameter
  const auto [TrailingEnableIf, TrailingParam] =
      matchTrailingTemplateParam(FunctionTemplate);
  if (!TrailingEnableIf || !TrailingParam)
    return;
  diag(TrailingParam->getSourceRange().getBegin(),
       "use C++20 requires constraints instead of enable_if")
      << handleTrailingTemplateType(FunctionTemplate, Function, TrailingParam,
                                    *TrailingEnableIf, ASTCtx);
}
} // namespace clang::tidy::modernize