clang libtooling

How to match all return statements from a cxxMethodDecl node


I want to find all methods that take a parameter of reference type, and add some code before every return statement in them.

Here is my code:

Matcher.addMatcher(cxxMethodDecl().bind("r"), &HandlerForReturn);

// In HandlerForReturn's run(const MatchFinder::MatchResult &Result):
const CXXMethodDecl *re = Result.Nodes.getNodeAs<CXXMethodDecl>("r");
if (re && Result.SourceManager->isWrittenInMainFile(re->getBeginLoc())) {
    if (re->getNameAsString() == "ChkMemCanUse") {
        // parameters() avoids the signed/unsigned index loop over getNumParams()
        for (const ParmVarDecl *paramDecl : re->parameters()) {
            if (paramDecl->getType()->isReferenceType()) {
                //TODO
            }
        }
    }
}

I'm new to Clang. How can I find all the return statements in a CXXMethodDecl? Or is there another way to solve this?


Solution

  • Clang AST matchers are not really designed to match, and more importantly to bind, a variable number of nodes: each match result binds at most one node per ID, so "all return statements of this method" cannot be bound as one result.

    So I suggest keeping your current code that finds the interesting methods and gathering all of their return statements yourself. This is actually pretty easy with the clang::RecursiveASTVisitor template.
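
    The reason this is easy is that RecursiveASTVisitor uses the CRTP pattern: the base class template walks every node of the subtree and calls the derived class's VisitXxx hook for nodes of type Xxx, so a visitor only defines the hooks it cares about. Below is a toy, self-contained model of that dispatch in plain C++. The Stmt, ToyVisitor and ToyReturnCollector types and the helper functions are hypothetical stand-ins, not Clang classes; the real visitor supports many more node kinds and traversal options.

    ```cpp
    #include <cassert>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    // Hypothetical, minimal AST node (NOT clang::Stmt): a kind, a label, children.
    struct Stmt {
      enum Kind { Compound, If, Return };
      Kind K;
      std::string Text;
      std::vector<std::unique_ptr<Stmt>> Children;
      Stmt(Kind NodeKind, std::string Label)
          : K(NodeKind), Text(std::move(Label)) {}
    };

    // Toy CRTP base loosely modelled on RecursiveASTVisitor: it walks every node
    // in pre-order and calls Derived::VisitReturnStmt for each Return node.
    template <typename Derived>
    class ToyVisitor {
    public:
      bool traverse(Stmt *S) {
        if (!S)
          return true;
        // Static dispatch to the derived class's hook (no virtual functions).
        if (S->K == Stmt::Return && !derived().VisitReturnStmt(S))
          return false; // a hook returning false aborts the whole traversal
        for (auto &Child : S->Children)
          if (!traverse(Child.get()))
            return false;
        return true;
      }

      // Default hook: ignore the node and keep going.
      bool VisitReturnStmt(Stmt *) { return true; }

    private:
      Derived &derived() { return *static_cast<Derived *>(this); }
    };

    // Counterpart of a return collector: only the hook it needs is defined.
    class ToyReturnCollector : public ToyVisitor<ToyReturnCollector> {
    public:
      std::vector<Stmt *> Returns;
      bool VisitReturnStmt(Stmt *S) {
        Returns.push_back(S);
        return true;
      }
    };

    // Appends a new child of kind K under Parent and returns a pointer to it.
    Stmt *addChild(Stmt &Parent, Stmt::Kind K, std::string Text = "") {
      Parent.Children.push_back(std::make_unique<Stmt>(K, std::move(Text)));
      return Parent.Children.back().get();
    }

    // Builds a tree shaped like the body of A::foo and collects its returns.
    std::vector<std::string> demoReturnTexts() {
      Stmt Body(Stmt::Compound, "");
      Stmt *Outer = addChild(Body, Stmt::If);
      Stmt *Inner = addChild(*Outer, Stmt::If);
      addChild(*Inner, Stmt::Return, "return 20");
      addChild(*Outer, Stmt::Return, "return x + x / 2");
      addChild(Body, Stmt::Return, "return 10");

      ToyReturnCollector Collector;
      bool Completed = Collector.traverse(&Body);
      assert(Completed);
      (void)Completed; // silence unused-variable warnings under NDEBUG

      std::vector<std::string> Texts;
      for (Stmt *R : Collector.Returns)
        Texts.push_back(R->Text);
      return Texts;
    }
    ```

    demoReturnTexts() returns the labels in source order ("return 20", "return x + x / 2", "return 10"), because a pre-order walk reaches nested returns before later ones.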

    Here is how it can be done:

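    #include "clang/AST/RecursiveASTVisitor.h"
    #include "llvm/ADT/SmallVector.h"

    class ReturnCollector : public clang::RecursiveASTVisitor<ReturnCollector> {
    public:
      static constexpr auto AVERAGE_NUMBER_OF_RETURNS = 5;
      using Returns = llvm::SmallVector<clang::ReturnStmt *,
                                        AVERAGE_NUMBER_OF_RETURNS>;
    
      // Takes a const pointer so the result of getNodeAs<CXXMethodDecl>() can be
      // passed directly; the visitor only reads the AST.
      static Returns collect(const clang::CXXMethodDecl *MD) {
        ReturnCollector ActualCollector;
        ActualCollector.TraverseDecl(const_cast<clang::CXXMethodDecl *>(MD));
        return ActualCollector.Visited;
      }
    
      // Called once for every ReturnStmt in the traversed subtree.
      bool VisitReturnStmt(clang::ReturnStmt *RS) {
        Visited.push_back(RS);
        return true; // keep traversing
      }
    
      // Skip lambda bodies: their returns belong to the lambda, not the method.
      bool TraverseLambdaExpr(clang::LambdaExpr *) { return true; }
    
    private:
      ReturnCollector() = default;
    
      Returns Visited;
    };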
    

    It can be used like this:

    // MD: the clang::CXXMethodDecl * of one of the methods you are interested in
    auto ReturnStmts = ReturnCollector::collect(MD);
    
    llvm::errs() << "Returns of the '" << MD->getName() << "' method:\n";
    for (auto *Return : ReturnStmts) {
      Return->dump();
    }
    llvm::errs() << "\n";
    

    This code applied to the following snippet:

    class A {
      int foo(int x) {
        if (x > 10) {
          if (x < 100) {
            return 20;
          }
          return x + x / 2;
        }
        return 10;
      }
    
      int bar() {
        return 42;
      }
    };
    

    produces this output:

    Returns of the 'foo' method:
    ReturnStmt 0x3e6e6b0
    `-IntegerLiteral 0x3e6e690 'int' 20
    ReturnStmt 0x3e6e7c0
    `-BinaryOperator 0x3e6e7a0 'int' '+'
      |-ImplicitCastExpr 0x3e6e788 'int' <LValueToRValue>
      | `-DeclRefExpr 0x3e6e6f0 'int' lvalue ParmVar 0x3e6e308 'x' 'int'
      `-BinaryOperator 0x3e6e768 'int' '/'
        |-ImplicitCastExpr 0x3e6e750 'int' <LValueToRValue>
        | `-DeclRefExpr 0x3e6e710 'int' lvalue ParmVar 0x3e6e308 'x' 'int'
        `-IntegerLiteral 0x3e6e730 'int' 2
    ReturnStmt 0x3e6e828
    `-IntegerLiteral 0x3e6e808 'int' 10
    
    Returns of the 'bar' method:
    ReturnStmt 0x3e6e878
    `-IntegerLiteral 0x3e6e858 'int' 42
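
    Since the goal is to add code before each of these return statements, the usual next step in a real tool is clang::Rewriter: for every collected ReturnStmt, call InsertTextBefore(RS->getBeginLoc(), ...), and the Rewriter keeps track of how earlier insertions shift the text. The self-contained sketch below only illustrates that idea on a plain std::string. insertBeforeEach and demoInsert are hypothetical helpers, the offsets come from a naive text search standing in for what SourceManager::getFileOffset would give for each ReturnStmt, and log_exit() is a placeholder for the code you want to inject.

    ```cpp
    #include <algorithm>
    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <string>
    #include <vector>

    // Inserts Snippet before each offset in Offsets. Offsets are byte positions
    // into the ORIGINAL Source; offsets past its end are ignored.
    std::string insertBeforeEach(std::string Source,
                                 std::vector<std::size_t> Offsets,
                                 const std::string &Snippet) {
      // Work from the back so an insertion never shifts an offset not yet used.
      std::sort(Offsets.begin(), Offsets.end(), std::greater<std::size_t>());
      Offsets.erase(std::unique(Offsets.begin(), Offsets.end()), Offsets.end());
      for (std::size_t Off : Offsets)
        if (Off <= Source.size())
          Source.insert(Off, Snippet);
      return Source;
    }

    // Hypothetical demo: finds the "return" keywords by text search (a stand-in
    // for real AST locations) and injects a placeholder call before each one.
    std::string demoInsert() {
      const std::string Src =
          "int bar(int &r) {\n  if (r) {\n    return 1;\n  }\n  return 0;\n}\n";
      std::vector<std::size_t> Offsets;
      for (std::size_t Pos = Src.find("return"); Pos != std::string::npos;
           Pos = Src.find("return", Pos + 1))
        Offsets.push_back(Pos);
      assert(Offsets.size() == 2); // the sample has exactly two returns
      return insertBeforeEach(Src, Offsets, "log_exit(); ");
    }
    ```

    Applying the offsets from last to first keeps every remaining offset valid without extra bookkeeping. One caveat applies however you rewrite: a brace-less body such as "if (c) return 1;" would become "if (c) log_exit(); return 1;", which changes control flow, so such returns need to be wrapped in braces as part of the rewrite.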
    
    

    I hope this helps you solve your problem!