blob: 8b20ef34c3ff2db51ad62f9d7f193b320bf87c16 [file] [log] [blame]
//===--- Stencil.cpp - Stencil implementation -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "clang/Tooling/Transformer/Stencil.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/ASTTypeTraits.h"
#include "clang/AST/Expr.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Lex/Lexer.h"
#include "clang/Tooling/Transformer/SourceCode.h"
#include "clang/Tooling/Transformer/SourceCodeBuilders.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include <atomic>
#include <memory>
#include <string>
using namespace clang;
using namespace transformer;
using ast_matchers::BoundNodes;
using ast_matchers::MatchFinder;
using llvm::errc;
using llvm::Error;
using llvm::Expected;
using llvm::StringError;
static llvm::Expected<DynTypedNode> getNode(const BoundNodes &Nodes,
StringRef Id) {
auto &NodesMap = Nodes.getMap();
auto It = NodesMap.find(Id);
if (It == NodesMap.end())
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
"Id not bound: " + Id);
return It->second;
}
static Error printNode(StringRef Id, const MatchFinder::MatchResult &Match,
std::string *Result) {
std::string Output;
llvm::raw_string_ostream Os(Output);
auto NodeOrErr = getNode(Match.Nodes, Id);
if (auto Err = NodeOrErr.takeError())
return Err;
NodeOrErr->print(Os, PrintingPolicy(Match.Context->getLangOpts()));
*Result += Os.str();
return Error::success();
}
// FIXME: Consider memoizing this function using the `ASTContext`.
static bool isSmartPointerType(QualType Ty, ASTContext &Context) {
using namespace ::clang::ast_matchers;
// Optimization: hard-code common smart-pointer types. This can/should be
// removed if we start caching the results of this function.
auto KnownSmartPointer =
cxxRecordDecl(hasAnyName("::std::unique_ptr", "::std::shared_ptr"));
const auto QuacksLikeASmartPointer = cxxRecordDecl(
hasMethod(cxxMethodDecl(hasOverloadedOperatorName("->"),
returns(qualType(pointsTo(type()))))),
hasMethod(cxxMethodDecl(hasOverloadedOperatorName("*"),
returns(qualType(references(type()))))));
const auto SmartPointer = qualType(hasDeclaration(
cxxRecordDecl(anyOf(KnownSmartPointer, QuacksLikeASmartPointer))));
return match(SmartPointer, Ty, Context).size() > 0;
}
// Identifies use of `operator*` on smart pointers, and returns the underlying
// smart-pointer expression; otherwise, returns null.
static const Expr *isSmartDereference(const Expr &E, ASTContext &Context) {
using namespace ::clang::ast_matchers;
const auto HasOverloadedArrow = cxxRecordDecl(hasMethod(cxxMethodDecl(
hasOverloadedOperatorName("->"), returns(qualType(pointsTo(type()))))));
// Verify it is a smart pointer by finding `operator->` in the class
// declaration.
auto Deref = cxxOperatorCallExpr(
hasOverloadedOperatorName("*"), hasUnaryOperand(expr().bind("arg")),
callee(cxxMethodDecl(ofClass(HasOverloadedArrow))));
return selectFirst<Expr>("arg", match(Deref, E, Context));
}
namespace {
// An arbitrary fragment of code within a stencil.
class RawTextStencil : public StencilInterface {
std::string Text;
public:
explicit RawTextStencil(std::string T) : Text(std::move(T)) {}
std::string toString() const override {
std::string Result;
llvm::raw_string_ostream OS(Result);
OS << "\"";
OS.write_escaped(Text);
OS << "\"";
OS.flush();
return Result;
}
Error eval(const MatchFinder::MatchResult &Match,
std::string *Result) const override {
Result->append(Text);
return Error::success();
}
};
// A debugging operation to dump the AST for a particular (bound) AST node.
class DebugPrintNodeStencil : public StencilInterface {
std::string Id;
public:
explicit DebugPrintNodeStencil(std::string S) : Id(std::move(S)) {}
std::string toString() const override {
return (llvm::Twine("dPrint(\"") + Id + "\")").str();
}
Error eval(const MatchFinder::MatchResult &Match,
std::string *Result) const override {
return printNode(Id, Match, Result);
}
};
// Operators that take a single node Id as an argument.
enum class UnaryNodeOperator {
Parens,
Deref,
MaybeDeref,
AddressOf,
MaybeAddressOf,
Describe,
};
// Generic container for stencil operations with a (single) node-id argument.
class UnaryOperationStencil : public StencilInterface {
UnaryNodeOperator Op;
std::string Id;
public:
UnaryOperationStencil(UnaryNodeOperator Op, std::string Id)
: Op(Op), Id(std::move(Id)) {}
std::string toString() const override {
StringRef OpName;
switch (Op) {
case UnaryNodeOperator::Parens:
OpName = "expression";
break;
case UnaryNodeOperator::Deref:
OpName = "deref";
break;
case UnaryNodeOperator::MaybeDeref:
OpName = "maybeDeref";
break;
case UnaryNodeOperator::AddressOf:
OpName = "addressOf";
break;
case UnaryNodeOperator::MaybeAddressOf:
OpName = "maybeAddressOf";
break;
case UnaryNodeOperator::Describe:
OpName = "describe";
break;
}
return (OpName + "(\"" + Id + "\")").str();
}
Error eval(const MatchFinder::MatchResult &Match,
std::string *Result) const override {
// The `Describe` operation can be applied to any node, not just
// expressions, so it is handled here, separately.
if (Op == UnaryNodeOperator::Describe)
return printNode(Id, Match, Result);
const auto *E = Match.Nodes.getNodeAs<Expr>(Id);
if (E == nullptr)
return llvm::make_error<StringError>(errc::invalid_argument,
"Id not bound or not Expr: " + Id);
llvm::Optional<std::string> Source;
switch (Op) {
case UnaryNodeOperator::Parens:
Source = tooling::buildParens(*E, *Match.Context);
break;
case UnaryNodeOperator::Deref:
Source = tooling::buildDereference(*E, *Match.Context);
break;
case UnaryNodeOperator::MaybeDeref:
if (E->getType()->isAnyPointerType() ||
isSmartPointerType(E->getType(), *Match.Context)) {
// Strip off any operator->. This can only occur inside an actual arrow
// member access, so we treat it as equivalent to an actual object
// expression.
if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(E)) {
if (OpCall->getOperator() == clang::OO_Arrow &&
OpCall->getNumArgs() == 1) {
E = OpCall->getArg(0);
}
}
Source = tooling::buildDereference(*E, *Match.Context);
break;
}
*Result += tooling::getText(*E, *Match.Context);
return Error::success();
case UnaryNodeOperator::AddressOf:
Source = tooling::buildAddressOf(*E, *Match.Context);
break;
case UnaryNodeOperator::MaybeAddressOf:
if (E->getType()->isAnyPointerType() ||
isSmartPointerType(E->getType(), *Match.Context)) {
// Strip off any operator->. This can only occur inside an actual arrow
// member access, so we treat it as equivalent to an actual object
// expression.
if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(E)) {
if (OpCall->getOperator() == clang::OO_Arrow &&
OpCall->getNumArgs() == 1) {
E = OpCall->getArg(0);
}
}
*Result += tooling::getText(*E, *Match.Context);
return Error::success();
}
Source = tooling::buildAddressOf(*E, *Match.Context);
break;
case UnaryNodeOperator::Describe:
llvm_unreachable("This case is handled at the start of the function");
}
if (!Source)
return llvm::make_error<StringError>(
errc::invalid_argument,
"Could not construct expression source from ID: " + Id);
*Result += *Source;
return Error::success();
}
};
// The fragment of code corresponding to the selected range.
class SelectorStencil : public StencilInterface {
RangeSelector Selector;
public:
explicit SelectorStencil(RangeSelector S) : Selector(std::move(S)) {}
std::string toString() const override { return "selection(...)"; }
Error eval(const MatchFinder::MatchResult &Match,
std::string *Result) const override {
auto RawRange = Selector(Match);
if (!RawRange)
return RawRange.takeError();
CharSourceRange Range = Lexer::makeFileCharRange(
*RawRange, *Match.SourceManager, Match.Context->getLangOpts());
if (Range.isInvalid()) {
// Validate the original range to attempt to get a meaningful error
// message. If it's valid, then something else is the cause and we just
// return the generic failure message.
if (auto Err =
tooling::validateEditRange(*RawRange, *Match.SourceManager))
return handleErrors(std::move(Err), [](std::unique_ptr<StringError> E) {
assert(E->convertToErrorCode() ==
llvm::make_error_code(errc::invalid_argument) &&
"Validation errors must carry the invalid_argument code");
return llvm::createStringError(
errc::invalid_argument,
"selected range could not be resolved to a valid source range; " +
E->getMessage());
});
return llvm::createStringError(
errc::invalid_argument,
"selected range could not be resolved to a valid source range");
}
// Validate `Range`, because `makeFileCharRange` accepts some ranges that
// `validateEditRange` rejects.
if (auto Err = tooling::validateEditRange(Range, *Match.SourceManager))
return joinErrors(
llvm::createStringError(errc::invalid_argument,
"selected range is not valid for editing"),
std::move(Err));
*Result += tooling::getText(Range, *Match.Context);
return Error::success();
}
};
// A stencil operation to build a member access `e.m` or `e->m`, as appropriate.
class AccessStencil : public StencilInterface {
std::string BaseId;
Stencil Member;
public:
AccessStencil(StringRef BaseId, Stencil Member)
: BaseId(std::string(BaseId)), Member(std::move(Member)) {}
std::string toString() const override {
return (llvm::Twine("access(\"") + BaseId + "\", " + Member->toString() +
")")
.str();
}
Error eval(const MatchFinder::MatchResult &Match,
std::string *Result) const override {
const auto *E = Match.Nodes.getNodeAs<Expr>(BaseId);
if (E == nullptr)
return llvm::make_error<StringError>(errc::invalid_argument,
"Id not bound: " + BaseId);
if (!E->isImplicitCXXThis()) {
llvm::Optional<std::string> S;
if (E->getType()->isAnyPointerType() ||
isSmartPointerType(E->getType(), *Match.Context)) {
// Strip off any operator->. This can only occur inside an actual arrow
// member access, so we treat it as equivalent to an actual object
// expression.
if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(E)) {
if (OpCall->getOperator() == clang::OO_Arrow &&
OpCall->getNumArgs() == 1) {
E = OpCall->getArg(0);
}
}
S = tooling::buildArrow(*E, *Match.Context);
} else if (const auto *Operand = isSmartDereference(*E, *Match.Context)) {
// `buildDot` already handles the built-in dereference operator, so we
// only need to catch overloaded `operator*`.
S = tooling::buildArrow(*Operand, *Match.Context);
} else {
S = tooling::buildDot(*E, *Match.Context);
}
if (S.hasValue())
*Result += *S;
else
return llvm::make_error<StringError>(
errc::invalid_argument,
"Could not construct object text from ID: " + BaseId);
}
return Member->eval(Match, Result);
}
};
class IfBoundStencil : public StencilInterface {
std::string Id;
Stencil TrueStencil;
Stencil FalseStencil;
public:
IfBoundStencil(StringRef Id, Stencil TrueStencil, Stencil FalseStencil)
: Id(std::string(Id)), TrueStencil(std::move(TrueStencil)),
FalseStencil(std::move(FalseStencil)) {}
std::string toString() const override {
return (llvm::Twine("ifBound(\"") + Id + "\", " + TrueStencil->toString() +
", " + FalseStencil->toString() + ")")
.str();
}
Error eval(const MatchFinder::MatchResult &Match,
std::string *Result) const override {
auto &M = Match.Nodes.getMap();
return (M.find(Id) != M.end() ? TrueStencil : FalseStencil)
->eval(Match, Result);
}
};
class SelectBoundStencil : public clang::transformer::StencilInterface {
static bool containsNoNullStencils(
const std::vector<std::pair<std::string, Stencil>> &Cases) {
for (const auto &S : Cases)
if (S.second == nullptr)
return false;
return true;
}
public:
SelectBoundStencil(std::vector<std::pair<std::string, Stencil>> Cases,
Stencil Default)
: CaseStencils(std::move(Cases)), DefaultStencil(std::move(Default)) {
assert(containsNoNullStencils(CaseStencils) &&
"cases of selectBound may not be null");
}
~SelectBoundStencil() override{};
llvm::Error eval(const MatchFinder::MatchResult &match,
std::string *result) const override {
const BoundNodes::IDToNodeMap &NodeMap = match.Nodes.getMap();
for (const auto &S : CaseStencils) {
if (NodeMap.count(S.first) > 0) {
return S.second->eval(match, result);
}
}
if (DefaultStencil != nullptr) {
return DefaultStencil->eval(match, result);
}
llvm::SmallVector<llvm::StringRef, 2> CaseIDs;
CaseIDs.reserve(CaseStencils.size());
for (const auto &S : CaseStencils)
CaseIDs.emplace_back(S.first);
return llvm::createStringError(
errc::result_out_of_range,
llvm::Twine("selectBound failed: no cases bound and no default: {") +
llvm::join(CaseIDs, ", ") + "}");
}
std::string toString() const override {
std::string Buffer;
llvm::raw_string_ostream Stream(Buffer);
Stream << "selectBound({";
bool First = true;
for (const auto &S : CaseStencils) {
if (First)
First = false;
else
Stream << "}, ";
Stream << "{\"" << S.first << "\", " << S.second->toString();
}
Stream << "}}";
if (DefaultStencil != nullptr) {
Stream << ", " << DefaultStencil->toString();
}
Stream << ")";
return Stream.str();
}
private:
std::vector<std::pair<std::string, Stencil>> CaseStencils;
Stencil DefaultStencil;
};
class SequenceStencil : public StencilInterface {
std::vector<Stencil> Stencils;
public:
SequenceStencil(std::vector<Stencil> Stencils)
: Stencils(std::move(Stencils)) {}
std::string toString() const override {
llvm::SmallVector<std::string, 2> Parts;
Parts.reserve(Stencils.size());
for (const auto &S : Stencils)
Parts.push_back(S->toString());
return (llvm::Twine("seq(") + llvm::join(Parts, ", ") + ")").str();
}
Error eval(const MatchFinder::MatchResult &Match,
std::string *Result) const override {
for (const auto &S : Stencils)
if (auto Err = S->eval(Match, Result))
return Err;
return Error::success();
}
};
class RunStencil : public StencilInterface {
MatchConsumer<std::string> Consumer;
public:
explicit RunStencil(MatchConsumer<std::string> C) : Consumer(std::move(C)) {}
std::string toString() const override { return "run(...)"; }
Error eval(const MatchFinder::MatchResult &Match,
std::string *Result) const override {
Expected<std::string> Value = Consumer(Match);
if (!Value)
return Value.takeError();
*Result += *Value;
return Error::success();
}
};
} // namespace
Stencil transformer::detail::makeStencil(StringRef Text) {
return std::make_shared<RawTextStencil>(std::string(Text));
}
Stencil transformer::detail::makeStencil(RangeSelector Selector) {
return std::make_shared<SelectorStencil>(std::move(Selector));
}
Stencil transformer::dPrint(StringRef Id) {
return std::make_shared<DebugPrintNodeStencil>(std::string(Id));
}
Stencil transformer::expression(llvm::StringRef Id) {
return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::Parens,
std::string(Id));
}
Stencil transformer::deref(llvm::StringRef ExprId) {
return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::Deref,
std::string(ExprId));
}
Stencil transformer::maybeDeref(llvm::StringRef ExprId) {
return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::MaybeDeref,
std::string(ExprId));
}
Stencil transformer::addressOf(llvm::StringRef ExprId) {
return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::AddressOf,
std::string(ExprId));
}
Stencil transformer::maybeAddressOf(llvm::StringRef ExprId) {
return std::make_shared<UnaryOperationStencil>(
UnaryNodeOperator::MaybeAddressOf, std::string(ExprId));
}
Stencil transformer::describe(StringRef Id) {
return std::make_shared<UnaryOperationStencil>(UnaryNodeOperator::Describe,
std::string(Id));
}
Stencil transformer::access(StringRef BaseId, Stencil Member) {
return std::make_shared<AccessStencil>(BaseId, std::move(Member));
}
Stencil transformer::ifBound(StringRef Id, Stencil TrueStencil,
Stencil FalseStencil) {
return std::make_shared<IfBoundStencil>(Id, std::move(TrueStencil),
std::move(FalseStencil));
}
Stencil transformer::selectBound(
std::vector<std::pair<std::string, Stencil>> CaseStencils,
Stencil DefaultStencil) {
return std::make_shared<SelectBoundStencil>(std::move(CaseStencils),
std::move(DefaultStencil));
}
Stencil transformer::run(MatchConsumer<std::string> Fn) {
return std::make_shared<RunStencil>(std::move(Fn));
}
Stencil transformer::catVector(std::vector<Stencil> Parts) {
// Only one argument, so don't wrap in sequence.
if (Parts.size() == 1)
return std::move(Parts[0]);
return std::make_shared<SequenceStencil>(std::move(Parts));
}