DEV Community

WDSEGA


Understanding Python AST for Static Code Analysis

The Abstract Syntax Tree (AST) is one of Python's most powerful yet underappreciated features. Every Python program you write is first parsed into an AST before being compiled to bytecode. Understanding how ASTs work opens the door to building linters, code transformers, static analyzers, and other metaprogramming tools.

In this article, we'll dive deep into Python's ast module and learn how to leverage it for practical static code analysis.
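You can reproduce the parse-then-compile pipeline from the introduction by hand, because the built-in compile() accepts an ast.Module as well as source text. Here is a minimal sketch. The sample source string and the namespace dict are illustrative choices, not part of any API:

```python
import ast

source = "result = sum(n * n for n in range(4))"

# Step 1: source text -> AST
tree = ast.parse(source, filename="<example>", mode="exec")

# Step 2: AST -> code object (bytecode); compile() accepts an AST in place of a string
code_obj = compile(tree, "<example>", "exec")

# Step 3: run the bytecode in a fresh namespace (builtins are added automatically)
namespace = {}
exec(code_obj, namespace)
print(namespace["result"])  # 0 + 1 + 4 + 9 = 14
```

Anything that rewrites the tree between steps 1 and 2 changes what actually runs. That is the hook code transformers rely on.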

What Is an AST?

An Abstract Syntax Tree is a tree representation of the syntactic structure of source code. Each node in the tree represents a construct in the code — function definitions, variable assignments, loops, expressions, and so on.
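The word "abstract" means the tree keeps the structure of the code and drops surface syntax such as redundant parentheses. The small sketch below shows this. It uses mode="eval" to parse bare expressions, and the three expressions are arbitrary examples:

```python
import ast

# Parse three spellings of arithmetic as single expressions
plain = ast.parse("1 + 2 * 3", mode="eval").body
redundant = ast.parse("(1 + (2 * 3))", mode="eval").body
grouped = ast.parse("(1 + 2) * 3", mode="eval").body

# Precedence is encoded in the shape: the multiplication is the right child of the addition
print(type(plain.op).__name__, type(plain.right.op).__name__)  # Add Mult

# Redundant parentheses leave no trace, so the two dumps are identical
print(ast.dump(plain) == ast.dump(redundant))  # True

# Parentheses that change the meaning do change the tree: Mult is now the root
print(type(grouped.op).__name__)  # Mult
```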

Consider this simple function:

def greet(name):
    return f"Hello, {name}"

When Python parses this, it creates an AST that looks roughly like:

Module
 └── FunctionDef (name='greet')
      ├── arguments
      │    └── arg (arg='name')
      └── Return
           └── JoinedStr
                ├── Constant ('Hello, ')
                └── FormattedValue
                     └── Name (id='name')
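Each box in that diagram is a node object, and its children live in named fields such as body, args, value and values. You can walk the same path by hand. The quick sketch below re-parses the greet function to do so:

```python
import ast

tree = ast.parse('def greet(name):\n    return f"Hello, {name}"\n')

func = tree.body[0]                 # FunctionDef
ret = func.body[0]                  # Return
joined = ret.value                  # JoinedStr (the f-string)
literal, formatted = joined.values  # Constant, FormattedValue

print(func.name, func.args.args[0].arg)  # greet name
print(repr(literal.value))               # 'Hello, '
print(formatted.value.id)                # name
```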

Exploring the AST Module

Basic Parsing

import ast

code = """
def greet(name):
    return f"Hello, {name}"
"""

tree = ast.parse(code)
print(ast.dump(tree, indent=2))

The ast.dump() function gives you a complete string representation of the tree. The indent parameter (Python 3.9+) makes it readable.
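Static analysis tools usually need to report where a problem is, not just what it is. Statement and expression nodes carry lineno and end_lineno (1-based) plus col_offset and end_col_offset (0-based, measured in UTF-8 bytes). The ast.get_source_segment() function maps a node back to its source text. Here is a short sketch on the same greet source:

```python
import ast

code = """
def greet(name):
    return f"Hello, {name}"
"""

tree = ast.parse(code)
func = tree.body[0]

# The source starts with a newline, so the def sits on line 2
print(func.lineno, func.col_offset, func.end_lineno)  # 2 0 3

# Recover the exact source text of a node, here the return statement
print(ast.get_source_segment(code, func.body[0]))  # return f"Hello, {name}"

# include_attributes=True adds the position fields to the dump output
print(ast.dump(func.body[0].value, include_attributes=True))
```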

Walking the Tree

There are two main ways to traverse an AST:

1. ast.walk(): simple iteration over all nodes (the documentation guarantees no particular order)

for node in ast.walk(tree):
    if isinstance(node, ast.FunctionDef):
        print(f"Found function: {node.name}")
    elif isinstance(node, ast.Name):
        # Name covers every identifier reference (variables, functions, builtins)
        print(f"Found name: {node.id}")
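ast.walk() yields nodes with no information about where they sit in the tree. One common workaround is to record each node's parent before analyzing anything. The sketch below does this. The names parents, enclosing_function and calls are illustrative helpers, not part of the ast module:

```python
import ast

source = """
def load(path):
    return open(path).read()

def main():
    print(load("config.txt"))
"""

tree = ast.parse(source)

# ast.walk() gives no context, so map every child node to its parent up front
parents = {}
for node in ast.walk(tree):
    for child in ast.iter_child_nodes(node):
        parents[child] = node

def enclosing_function(node):
    """Climb parent links until a function definition is found (None at module level)."""
    while node in parents:
        node = parents[node]
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            return node.name
    return None

# Every call to a plain name, paired with the function it appears in
calls = [
    (node.func.id, enclosing_function(node))
    for node in ast.walk(tree)
    if isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
]
for name, func in calls:
    print(f"{name}() called in {func}")
```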

2. ast.NodeVisitor — Object-oriented traversal

class CodeAnalyzer(ast.NodeVisitor):
    def __init__(self):
        self.functions = []
        self.classes = []
        self.imports = []

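    def visit_FunctionDef(self, node):
        args = node.args
        self.functions.append({
            'name': node.name,
            'lineno': node.lineno,
            # Positional-only, regular, and keyword-only parameters
            'args': [a.arg for a in args.posonlyargs + args.args + args.kwonlyargs],
        })
        self.generic_visit(node)  # Continue to child nodes

    # "async def" produces AsyncFunctionDef nodes; reuse the same handler
    visit_AsyncFunctionDef = visit_FunctionDef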

    def visit_ClassDef(self, node):
        self.classes.append({
            'name': node.name,
            'lineno': node.lineno,
            'methods': [
                n.name for n in node.body
                if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
            ]
        })
        self.generic_visit(node)

    def visit_Import(self, node):
        for alias in node.names:
            self.imports.append(alias.name)
        self.generic_visit(node)

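    def visit_ImportFrom(self, node):
        # node.level counts the leading dots of a relative import;
        # node.module is None for "from . import x"
        prefix = '.' * node.level + (node.module or '')
        for alias in node.names:
            if node.module:
                self.imports.append(f"{prefix}.{alias.name}")
            else:
                self.imports.append(f"{prefix}{alias.name}")
        self.generic_visit(node)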

analyzer = CodeAnalyzer()
analyzer.visit(tree)
print(f"Functions: {analyzer.functions}")
print(f"Classes: {analyzer.classes}")
print(f"Imports: {analyzer.imports}")
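The dispatch rule behind NodeVisitor works like this. visit() calls visit_<ClassName> if the subclass defines it and generic_visit() otherwise, and generic_visit() is the method that recurses into children. A visit_ method that skips generic_visit() therefore stops the traversal at that node, which is why each handler above ends with self.generic_visit(node). Here is a minimal sketch with two hypothetical counters:

```python
import ast

source = """
def outer():
    def inner():
        pass
    return inner
"""

class ShallowCounter(ast.NodeVisitor):
    """Does NOT call generic_visit, so traversal stops at the first function."""
    def __init__(self):
        self.count = 0

    def visit_FunctionDef(self, node):
        self.count += 1

class DeepCounter(ShallowCounter):
    """Calls generic_visit, so nested definitions are reached as well."""
    def visit_FunctionDef(self, node):
        self.count += 1
        self.generic_visit(node)

tree = ast.parse(source)
shallow, deep = ShallowCounter(), DeepCounter()
shallow.visit(tree)
deep.visit(tree)
print(shallow.count, deep.count)  # 1 2
```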

Practical Example 1: Detect Code Smells

Let's build a static analyzer that detects common Python code smells.

class CodeSmellDetector(ast.NodeVisitor):
    def __init__(self):
        self.issues = []

    def _add_issue(self, lineno, message, severity="warning"):
        self.issues.append({
            'line': lineno,
            'message': message,
            'severity': severity
        })

    def visit_FunctionDef(self, node):
        # Check for functions that are too long (top-level statements, not lines)
        if len(node.body) > 20:
            self._add_issue(
                node.lineno,
                f"Function '{node.name}' is too long "
                f"({len(node.body)} statements, max 20)",
                "warning"
            )

        # Check for too many parameters (positional-only, regular, keyword-only)
        arg_count = (len(node.args.posonlyargs) + len(node.args.args)
                     + len(node.args.kwonlyargs))
        if node.args.vararg:
            arg_count += 1
        if node.args.kwarg:
            arg_count += 1
        if arg_count > 5:
            self._add_issue(
                node.lineno,
                f"Function '{node.name}' has {arg_count} parameters "
                f"(max 5 recommended)",
                "warning"
            )

        self.generic_visit(node)

    def visit_ExceptHandler(self, node):
        # Bare except clauses
        if node.type is None:
            self._add_issue(
                node.lineno,
                "Bare 'except:' clause catches all exceptions, "
                "including KeyboardInterrupt and SystemExit",
                "error"
            )
        self.generic_visit(node)

    def visit_ListComp(self, node):
        # Check for overly complex list comprehensions
        nested_depth = self._get_comprehension_depth(node)
        if nested_depth > 2:
            self._add_issue(
                node.lineno,
                f"Nested comprehension depth {nested_depth} "
                f"(max 2 recommended)",
                "warning"
            )
        self.generic_visit(node)

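    def _get_comprehension_depth(self, node, depth=1):
        """Depth of comprehensions nested in the element or in a generator's iterable."""
        comprehension_types = (ast.ListComp, ast.SetComp,
                               ast.DictComp, ast.GeneratorExp)
        # A DictComp has key/value instead of a single elt
        if isinstance(node, ast.DictComp):
            children = [node.key, node.value]
        else:
            children = [node.elt]
        children += [generator.iter for generator in node.generators]

        max_depth = depth
        for child in children:
            if isinstance(child, comprehension_types):
                max_depth = max(
                    max_depth,
                    self._get_comprehension_depth(child, depth + 1)
                )
        return max_depth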


def analyze_code(source: str) -> list:
    tree = ast.parse(source)
    detector = CodeSmellDetector()
    detector.visit(tree)
    return detector.issues
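The "too long" check above uses len(node.body), which counts only the top-level statements of a function. You may want a size measure that also sees nested blocks. The sketch below compares three options on one hypothetical function. Which one to put a threshold on is a judgment call:

```python
import ast

source = """
def configure(options):
    if options:
        for key, value in options.items():
            print(key, value)
            print("applied")
    return options
"""

func = ast.parse(source).body[0]

# len(node.body) only counts top-level statements: here the `if` and the `return`
top_level_statements = len(func.body)

# end_lineno - lineno + 1 counts physical lines, nested blocks included
line_span = func.end_lineno - func.lineno + 1

# Walking each top-level statement counts statements at every nesting depth
all_statements = sum(
    isinstance(n, ast.stmt) for stmt in func.body for n in ast.walk(stmt)
)
print(top_level_statements, line_span, all_statements)  # 2 6 5
```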

Practical Example 2: Finding Mutable Default Arguments

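This is one of the most common Python gotchas. Default values are evaluated once, when the def statement runs, so every call that doesn't pass its own value shares the same list or dict. The detector below flags such defaults: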

class MutableDefaultDetector(ast.NodeVisitor):
    """Detect mutable default arguments in function definitions."""

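    # Literals and comprehensions always build a fresh mutable object
    MUTABLE_NODES = (ast.List, ast.Dict, ast.Set,
                     ast.ListComp, ast.DictComp, ast.SetComp)
    # Calls are flagged only for well-known mutable constructors; any other
    # call is also evaluated once at definition time, but may return an immutable value
    MUTABLE_CALLS = {'list', 'dict', 'set', 'bytearray'}

    def __init__(self):
        self.violations = []

    def _is_mutable(self, node):
        if isinstance(node, self.MUTABLE_NODES):
            return True
        return (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id in self.MUTABLE_CALLS)

    def visit_FunctionDef(self, node):
        # kw_defaults holds None for keyword-only args without a default
        for default in node.args.defaults + node.args.kw_defaults:
            if default is not None and self._is_mutable(default):
                self.violations.append({
                    'function': node.name,
                    'line': node.lineno,
                    'type': type(default).__name__
                })
        self.generic_visit(node)

    # "async def" functions can have mutable defaults too
    visit_AsyncFunctionDef = visit_FunctionDef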


# Test it
code = """
def add_item(item, items=[]):
    items.append(item)
    return items

def create_cache(cache={}):
    return cache
"""

tree = ast.parse(code)
detector = MutableDefaultDetector()
detector.visit(tree)
for v in detector.violations:
    print(f"Line {v['line']}: '{v['function']}' has mutable "
          f"default argument (type: {v['type']})")
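The runtime sketch below reuses the add_item example to show the behavior this detector guards against. It then applies the conventional None-sentinel fix. The name add_item_fixed is illustrative:

```python
def add_item(item, items=[]):
    # The default list is created once, when `def` executes, and then reused
    items.append(item)
    return items

first = add_item("a")
second = add_item("b")
print(first, second, first is second)  # ['a', 'b'] ['a', 'b'] True
print(add_item.__defaults__)           # (['a', 'b'],)

def add_item_fixed(item, items=None):
    # Use None as a sentinel and build a fresh list on every call
    if items is None:
        items = []
    items.append(item)
    return items

print(add_item_fixed("a"), add_item_fixed("b"))  # ['a'] ['b']
```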

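Practical Example 3: Code Complexity Metrics

Cyclomatic complexity counts the independent paths through a function. Start at 1 and add 1 for every decision point: a branch, a loop, an exception handler, or each extra operand of an and/or chain.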

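class ComplexityAnalyzer(ast.NodeVisitor):
    # Nested functions and classes are scored on their own, so the
    # traversal below does not descend into them
    NESTED_SCOPES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)

    def __init__(self):
        self.function_complexity = {}

    def _own_nodes(self, node):
        """Yield descendants of node, skipping the bodies of nested scopes."""
        for child in ast.iter_child_nodes(node):
            yield child
            if not isinstance(child, self.NESTED_SCOPES):
                yield from self._own_nodes(child)

    def _calculate_complexity(self, node):
        """Calculate cyclomatic complexity of a function."""
        complexity = 1  # Base complexity

        for child in self._own_nodes(node):
            # Each decision point adds 1 to complexity
            if isinstance(child, (ast.If, ast.IfExp, ast.While, ast.For,
                                  ast.AsyncFor, ast.ExceptHandler)):
                complexity += 1
            elif isinstance(child, ast.BoolOp):
                # and/or add complexity for each additional operand
                complexity += len(child.values) - 1
            elif isinstance(child, ast.comprehension):
                # One for the loop plus one per filtering `if`
                complexity += 1 + len(child.ifs)

        return complexity

    def visit_FunctionDef(self, node):
        complexity = self._calculate_complexity(node)
        self.function_complexity[node.name] = {
            'complexity': complexity,
            'line': node.lineno,
            'end_line': node.end_lineno or node.lineno,
        }

        if complexity > 10:
            print(f"WARNING: '{node.name}' has high complexity "
                  f"({complexity}). Consider refactoring.")

        self.generic_visit(node)  # Nested functions get their own entry

    visit_AsyncFunctionDef = visit_FunctionDef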
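The following worked example applies the counting rule to a single function. It is a stripped-down count done inline with ast.walk(). It handles only If, While, For, ExceptHandler and BoolOp, which is all this hypothetical classify function needs:

```python
import ast

source = """
def classify(n):
    if n < 0 and n % 2 == 0:
        return "negative even"
    elif n < 0:
        return "negative"
    return "non-negative"
"""

func = ast.parse(source).body[0]

complexity = 1  # base path
for node in ast.walk(func):
    if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
        complexity += 1
    elif isinstance(node, ast.BoolOp):
        complexity += len(node.values) - 1

# 1 (base) + 1 (if) + 1 (elif, stored as a nested If in orelse) + 1 (and) = 4
print(complexity)  # 4
```

Note that elif has no node type of its own. The parser nests a second If inside the first one's orelse, so it is counted like any other branch.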

For the complete guide with all code examples and advanced patterns, read the full article on our blog.


Originally published at WD Tech Blog. Follow for more Python tutorials, AI tools, and developer resources.
