blob: b557066d979f50c1b8097574d6281a70133bd31e [file] [log] [blame]
//===--- OverridePureVirtuals.cpp --------------------------------*- C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Tweak to automatically generate stubs for pure virtual methods inherited from
// base classes.
//
// Purpose:
// - Simplifies making a derived class concrete by automating the creation of
// required method overrides from abstract bases.
//
// Tweak Summary:
//
// 1. Activation Conditions (prepare):
// - The tweak activates when the cursor is over a C++ class definition.
// - The class must be abstract (it, or its base classes, have unimplemented
// pure virtual functions).
// - It must also inherit from at least one other abstract class.
//
// 2. Identifying Missing Methods:
// - The tweak scans the inheritance hierarchy of the current class.
// - It identifies all unique pure virtual methods from base classes
// that are not yet implemented or overridden.
// - These missing methods are then grouped by their original access
// specifier (e.g., public, protected).
//
// 3. Code Generation and Insertion:
// - For each group of missing methods, stubs are inserted.
// - If an access specifier section (like `public:`) exists, stubs are
// inserted there; otherwise, a new section is created and appended.
// - Each generated stub includes the `override` keyword, a `// TODO:`
// comment, and a `static_assert(false, ...)` to force a compile-time
// error if the method remains unimplemented.
// - The base method's signature is adjusted (e.g., `virtual` and `= 0`
// are removed for the override).
//
// 4. Code Action Provided:
// - A single code action titled "Override pure virtual methods" is offered.
// - Applying this action results in a single source file modification
// containing all the generated method stubs.
//
// Example:
//
// class Base {
// public:
// virtual void publicMethod() = 0;
// protected:
// virtual auto privateMethod() const -> int = 0;
// };
//
// Before:
// // cursor here
// class Derived : public Base {}^;
//
// After:
//
// class Derived : public Base {
// public:
// void publicMethod() override {
// // TODO: Implement this pure virtual method.
// static_assert(false, "Method `publicMethod` is not implemented.");
// }
//
// protected:
// auto privateMethod() const -> int override {
// // TODO: Implement this pure virtual method.
// static_assert(false, "Method `privateMethod` is not implemented.");
// }
// };
//
//===----------------------------------------------------------------------===//
#include "refactor/Tweak.h"
#include "support/Token.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/TypeBase.h"
#include "clang/AST/TypeLoc.h"
#include "clang/Basic/LLVM.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Tooling/Core/Replacement.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/FormatVariadic.h"
#include <string>
namespace clang {
namespace clangd {
namespace {
// This function removes the "virtual" and the "= 0" at the end;
// e.g.:
// "virtual void foo(int var = 0) = 0" // input.
// "void foo(int var = 0)" // output.
std::string removePureVirtualSyntax(const std::string &MethodDecl,
const LangOptions &LangOpts) {
assert(!MethodDecl.empty());
TokenStream TS = lex(MethodDecl, LangOpts);
std::string DeclString;
for (const clangd::Token &Tk : TS.tokens()) {
if (Tk.Kind == clang::tok::raw_identifier && Tk.text() == "virtual")
continue;
// If the ending two tokens are "= 0", we break here and we already have the
// method's string without the pure virtual syntax.
const auto &Next = Tk.next();
if (Next.next().Kind == tok::eof && Tk.Kind == clang::tok::equal &&
Next.text() == "0")
break;
DeclString += Tk.text();
if (Tk.Kind != tok::l_paren && Next.Kind != tok::comma &&
Next.Kind != tok::r_paren && Next.Kind != tok::l_paren &&
Tk.Kind != tok::coloncolon && Next.Kind != tok::coloncolon)
DeclString += ' ';
}
// Trim the last whitespace.
if (DeclString.back() == ' ')
DeclString.pop_back();
return DeclString;
}
class OverridePureVirtuals final : public Tweak {
public:
const char *id() const final; // defined by REGISTER_TWEAK.
bool prepare(const Selection &Sel) override;
Expected<Effect> apply(const Selection &Sel) override;
std::string title() const override { return "Override pure virtual methods"; }
llvm::StringLiteral kind() const override {
return CodeAction::QUICKFIX_KIND;
}
private:
// Stores the CXXRecordDecl of the class being modified.
const CXXRecordDecl *CurrentDeclDef = nullptr;
// Stores pure virtual methods that need overriding, grouped by their original
// access specifier.
llvm::MapVector<AccessSpecifier, llvm::SmallVector<const CXXMethodDecl *>>
MissingMethodsByAccess;
// Stores the source locations of existing access specifiers in CurrentDecl.
llvm::MapVector<AccessSpecifier, SourceLocation> AccessSpecifierLocations;
// Helper function to gather information before applying the tweak.
void collectMissingPureVirtuals();
};
REGISTER_TWEAK(OverridePureVirtuals)
// Function to get all unique pure virtual methods from the entire
// base class hierarchy of CurrentDeclDef.
llvm::SmallVector<const clang::CXXMethodDecl *>
getAllUniquePureVirtualsFromBaseHierarchy(
const clang::CXXRecordDecl *CurrentDeclDef) {
llvm::SmallVector<const clang::CXXMethodDecl *> AllPureVirtualsInHierarchy;
llvm::DenseSet<const clang::CXXMethodDecl *> CanonicalPureVirtualsSeen;
if (!CurrentDeclDef || !CurrentDeclDef->getDefinition())
return AllPureVirtualsInHierarchy;
const clang::CXXRecordDecl *Def = CurrentDeclDef->getDefinition();
Def->forallBases([&](const clang::CXXRecordDecl *BaseDefinition) {
for (const clang::CXXMethodDecl *Method : BaseDefinition->methods()) {
if (Method->isPureVirtual() &&
CanonicalPureVirtualsSeen.insert(Method->getCanonicalDecl()).second)
AllPureVirtualsInHierarchy.emplace_back(Method);
}
// Continue iterating through all bases.
return true;
});
return AllPureVirtualsInHierarchy;
}
// Gets canonical declarations of methods already overridden or implemented in
// class D.
llvm::SetVector<const CXXMethodDecl *>
getImplementedOrOverriddenCanonicals(const CXXRecordDecl *D) {
llvm::SetVector<const CXXMethodDecl *> ImplementedSet;
for (const CXXMethodDecl *M : D->methods()) {
// If M provides an implementation for any virtual method it overrides.
// A method is an "implementation" if it's virtual and not pure.
// Or if it directly overrides a base method.
for (const CXXMethodDecl *OverriddenM : M->overridden_methods())
ImplementedSet.insert(OverriddenM->getCanonicalDecl());
}
return ImplementedSet;
}
// Get the location of every colon of the `AccessSpecifier`.
llvm::MapVector<AccessSpecifier, SourceLocation>
getSpecifierLocations(const CXXRecordDecl *D) {
llvm::MapVector<AccessSpecifier, SourceLocation> Locs;
for (auto *DeclNode : D->decls()) {
if (const auto *ASD = llvm::dyn_cast<AccessSpecDecl>(DeclNode))
Locs[ASD->getAccess()] = ASD->getColonLoc();
}
return Locs;
}
bool hasAbstractBaseAncestor(const clang::CXXRecordDecl *CurrentDecl) {
assert(CurrentDecl && CurrentDecl->getDefinition());
return llvm::any_of(
CurrentDecl->getDefinition()->bases(), [](CXXBaseSpecifier BaseSpec) {
const auto *D = BaseSpec.getType()->getAsCXXRecordDecl();
const auto *Def = D ? D->getDefinition() : nullptr;
return Def && Def->isAbstract();
});
}
// The tweak is available if the selection is over an abstract C++ class
// definition that also inherits from at least one other abstract class.
bool OverridePureVirtuals::prepare(const Selection &Sel) {
const SelectionTree::Node *Node = Sel.ASTSelection.commonAncestor();
if (!Node)
return false;
// Make sure we have a definition.
CurrentDeclDef = Node->ASTNode.get<CXXRecordDecl>();
if (!CurrentDeclDef || !CurrentDeclDef->getDefinition())
return false;
// From now on, we should work with the definition.
CurrentDeclDef = CurrentDeclDef->getDefinition();
// Only offer for abstract classes with abstract bases.
return CurrentDeclDef->isAbstract() &&
hasAbstractBaseAncestor(CurrentDeclDef);
}
// Collects all pure virtual methods from base classes that `CurrentDeclDef` has
// not yet overridden, grouped by their original access specifier.
//
// Results are stored in `MissingMethodsByAccess` and `AccessSpecifierLocations`
// is also populated.
void OverridePureVirtuals::collectMissingPureVirtuals() {
if (!CurrentDeclDef)
return;
AccessSpecifierLocations = getSpecifierLocations(CurrentDeclDef);
MissingMethodsByAccess.clear();
// Get all unique pure virtual methods from the entire base class hierarchy.
llvm::SmallVector<const CXXMethodDecl *> AllPureVirtualsInHierarchy =
getAllUniquePureVirtualsFromBaseHierarchy(CurrentDeclDef);
// Get methods already implemented or overridden in CurrentDecl.
const auto ImplementedOrOverriddenSet =
getImplementedOrOverriddenCanonicals(CurrentDeclDef);
// Filter AllPureVirtualsInHierarchy to find those not in
// ImplementedOrOverriddenSet, which needs to be overriden.
for (const CXXMethodDecl *BaseMethod : AllPureVirtualsInHierarchy) {
bool AlreadyHandled = ImplementedOrOverriddenSet.contains(BaseMethod);
if (!AlreadyHandled)
MissingMethodsByAccess[BaseMethod->getAccess()].emplace_back(BaseMethod);
}
}
std::string generateOverrideString(const CXXMethodDecl *Method,
const LangOptions &LangOpts) {
std::string MethodDecl;
auto OS = llvm::raw_string_ostream(MethodDecl);
Method->print(OS);
return llvm::formatv(
"\n {0} override {{\n"
" // TODO: Implement this pure virtual method.\n"
" static_assert(false, \"Method `{1}` is not implemented.\");\n"
" }",
removePureVirtualSyntax(MethodDecl, LangOpts), Method->getName())
.str();
}
// Free function to generate the string for a group of method overrides.
std::string generateOverridesStringForGroup(
llvm::SmallVector<const CXXMethodDecl *> Methods,
const LangOptions &LangOpts) {
llvm::SmallVector<std::string> MethodsString;
MethodsString.reserve(Methods.size());
for (const CXXMethodDecl *Method : Methods) {
MethodsString.emplace_back(generateOverrideString(Method, LangOpts));
}
return llvm::join(MethodsString, "\n") + '\n';
}
Expected<Tweak::Effect> OverridePureVirtuals::apply(const Selection &Sel) {
// The correctness of this tweak heavily relies on the accurate population of
// these members.
collectMissingPureVirtuals();
// The `prepare` should prevent this. If the prepare identifies an abstract
// method, then is must have missing methods.
assert(!MissingMethodsByAccess.empty());
const auto &SM = Sel.AST->getSourceManager();
const auto &LangOpts = Sel.AST->getLangOpts();
tooling::Replacements EditReplacements;
// Stores text for new access specifier sections that are not already present
// in the class.
// Example:
// public: // ...
// protected: // ...
std::string NewSectionsToAppendText;
for (const auto &[AS, Methods] : MissingMethodsByAccess) {
assert(!Methods.empty());
std::string MethodsGroupString =
generateOverridesStringForGroup(Methods, LangOpts);
auto *ExistingSpecLocIter = AccessSpecifierLocations.find(AS);
bool ASExists = ExistingSpecLocIter != AccessSpecifierLocations.end();
if (ASExists) {
// Access specifier section already exists in the class.
// Get location immediately *after* the colon.
SourceLocation InsertLoc =
ExistingSpecLocIter->second.getLocWithOffset(1);
// Create a replacement to insert the method declarations.
// The replacement is at InsertLoc, has length 0 (insertion), and uses
// InsertionText.
std::string InsertionText = MethodsGroupString;
tooling::Replacement Rep(SM, InsertLoc, 0, InsertionText);
if (auto Err = EditReplacements.add(Rep))
return llvm::Expected<Tweak::Effect>(std::move(Err));
} else {
// Access specifier section does not exist in the class.
// These methods will be grouped into NewSectionsToAppendText and added
// towards the end of the class definition.
NewSectionsToAppendText +=
getAccessSpelling(AS).str() + ':' + MethodsGroupString;
}
}
// After processing all access specifiers, add any newly created sections
// (stored in NewSectionsToAppendText) to the end of the class.
if (!NewSectionsToAppendText.empty()) {
// AppendLoc is the SourceLocation of the closing brace '}' of the class.
// The replacement will insert text *before* this closing brace.
SourceLocation AppendLoc = CurrentDeclDef->getBraceRange().getEnd();
std::string FinalAppendText = std::move(NewSectionsToAppendText);
if (!CurrentDeclDef->decls_empty() || !EditReplacements.empty()) {
FinalAppendText = '\n' + FinalAppendText;
}
// Create a replacement to append the new sections.
tooling::Replacement Rep(SM, AppendLoc, 0, FinalAppendText);
if (auto Err = EditReplacements.add(Rep))
return llvm::Expected<Tweak::Effect>(std::move(Err));
}
if (EditReplacements.empty()) {
return llvm::make_error<llvm::StringError>(
"No changes to apply (internal error or no methods generated).",
llvm::inconvertibleErrorCode());
}
// Return the collected replacements as the effect of this tweak.
return Effect::mainFileEdit(SM, EditReplacements);
}
} // namespace
} // namespace clangd
} // namespace clang