The Abstract Syntax Tree (AST) is one of Python's most powerful yet underappreciated features. Every Python program you write is first parsed into an AST before being compiled to bytecode. Understanding how ASTs work opens the door to building linters, code transformers, static analyzers, and other metaprogramming tools.
In this article, we'll dive deep into Python's ast module and learn how to leverage it for practical static code analysis.
What Is an AST?
An Abstract Syntax Tree is a tree representation of the syntactic structure of source code. Each node in the tree represents a construct in the code — function definitions, variable assignments, loops, expressions, and so on.
Consider this simple function:
# Sample function whose AST is sketched below. Only '#' comments are used
# here because a docstring would appear as an extra node in the tree.
def greet(name):
    # The f-string is parsed into a JoinedStr node.
    return f"Hello, {name}"
When Python parses this, it creates an AST that looks roughly like:
Module
└── FunctionDef (name='greet')
    ├── arguments
    │   └── arg (arg='name')
    └── Return
        └── JoinedStr
            ├── Constant ('Hello, ')
            └── FormattedValue
                └── Name (id='name')
Exploring the AST Module
Basic Parsing
import ast
# Sample source to parse. The function body must be indented inside the
# string literal; otherwise ast.parse() raises IndentationError.
code = """
def greet(name):
    return f"Hello, {name}"
"""
tree = ast.parse(code)
# indent= (Python 3.9+) pretty-prints the nested node structure.
print(ast.dump(tree, indent=2))
The ast.dump() function gives you a complete string representation of the tree. The indent parameter (Python 3.9+) makes it readable.
Walking the Tree
There are two main ways to traverse an AST:
1. ast.walk() — simple iteration over every node in the tree (in no guaranteed order, and without any parent/context information)
# Visit every node once and report function definitions and Name nodes.
for node in ast.walk(tree):
    if isinstance(node, ast.Name):
        print(f"Found variable: {node.id}")
        continue
    if isinstance(node, ast.FunctionDef):
        print(f"Found function: {node.name}")
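2. ast.NodeVisitor — object-oriented traversal. Subclass it and define visit_<NodeType> methods (e.g. visit_FunctionDef); visit() dispatches each node to the matching method, and a method must call self.generic_visit(node) if traversal should continue into that node's children.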
class CodeAnalyzer(ast.NodeVisitor):
    """Collect functions, classes and imports from a parsed module.

    After ``visit(tree)`` the results are available as:

    * ``functions`` -- one dict per function or method (sync or async) with
      its ``name``, ``lineno`` and named parameters (``args``).
    * ``classes`` -- one dict per class with its ``name``, ``lineno`` and the
      names of the (sync or async) methods defined directly in its body.
    * ``imports`` -- dotted names from ``import`` / ``from ... import``
      statements; relative imports keep their leading dots.
    """

    def __init__(self):
        self.functions = []
        self.classes = []
        self.imports = []

    def visit_FunctionDef(self, node):
        # Named parameters in signature order: positional-only, regular,
        # keyword-only. *args / **kwargs are not included.
        params = node.args.posonlyargs + node.args.args + node.args.kwonlyargs
        self.functions.append({
            'name': node.name,
            'lineno': node.lineno,
            'args': [arg.arg for arg in params],
        })
        self.generic_visit(node)  # Continue to child nodes (nested defs)

    # ``async def`` produces a separate node type with the same fields.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_ClassDef(self, node):
        self.classes.append({
            'name': node.name,
            'lineno': node.lineno,
            'methods': [
                n.name for n in node.body
                if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
            ]
        })
        self.generic_visit(node)

    def visit_Import(self, node):
        for alias in node.names:
            self.imports.append(alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        # ``level`` is the number of leading dots of a relative import
        # (``from ..pkg import x`` has level 2); keep them in the name.
        prefix = '.' * (node.level or 0)
        module = node.module or ''
        for alias in node.names:
            if module:
                self.imports.append(f"{prefix}{module}.{alias.name}")
            else:
                # ``from . import x`` has no module part.
                self.imports.append(f"{prefix}{alias.name}")
        self.generic_visit(node)
analyzer = CodeAnalyzer()
analyzer.visit(tree)
# Print each collected category on its own line.
for label, found in (
    ("Functions", analyzer.functions),
    ("Classes", analyzer.classes),
    ("Imports", analyzer.imports),
):
    print(f"{label}: {found}")
Practical Example 1: Detect Code Smells
Let's build a static analyzer that detects common Python code smells.
class CodeSmellDetector(ast.NodeVisitor):
    """Flag common code smells in a parsed module.

    Each finding is appended to ``self.issues`` as a dict with ``line``,
    ``message`` and ``severity`` keys. Checks performed:

    * functions (sync or async) with more than ``MAX_STATEMENTS`` direct
      body statements;
    * functions with more than ``MAX_PARAMS`` parameters of any kind
      (positional-only, regular, keyword-only, ``*args``, ``**kwargs``);
    * bare ``except:`` clauses;
    * comprehensions (list/set/dict/generator) nested deeper than
      ``MAX_COMPREHENSION_DEPTH``.
    """

    MAX_STATEMENTS = 20
    MAX_PARAMS = 5
    MAX_COMPREHENSION_DEPTH = 2

    _COMPREHENSION_TYPES = (ast.ListComp, ast.SetComp,
                            ast.DictComp, ast.GeneratorExp)

    def __init__(self):
        self.issues = []

    def _add_issue(self, lineno, message, severity="warning"):
        self.issues.append({
            'line': lineno,
            'message': message,
            'severity': severity
        })

    def visit_FunctionDef(self, node):
        # Check for functions that are too long. Only direct statements of
        # the body are counted, not statements nested inside if/for/with.
        if len(node.body) > self.MAX_STATEMENTS:
            self._add_issue(
                node.lineno,
                f"Function '{node.name}' is too long "
                f"({len(node.body)} statements, max {self.MAX_STATEMENTS})",
                "warning"
            )

        # Check for too many parameters. Positional-only parameters
        # (``def f(a, /)``) are stored in their own list.
        args = node.args
        arg_count = (len(args.posonlyargs) + len(args.args)
                     + len(args.kwonlyargs))
        if args.vararg:
            arg_count += 1
        if args.kwarg:
            arg_count += 1
        if arg_count > self.MAX_PARAMS:
            self._add_issue(
                node.lineno,
                f"Function '{node.name}' has {arg_count} parameters "
                f"(max {self.MAX_PARAMS} recommended)",
                "warning"
            )
        self.generic_visit(node)

    # ``async def`` functions get the same checks.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_ExceptHandler(self, node):
        # Bare except clauses
        if node.type is None:
            self._add_issue(
                node.lineno,
                "Bare 'except:' clause catches all exceptions, "
                "including KeyboardInterrupt and SystemExit",
                "error"
            )
        self.generic_visit(node)

    def visit_Compare(self, node):
        # No Compare-specific checks. NodeVisitor does not descend into the
        # children of a node whose visit_* method skips generic_visit(), so
        # keep traversing to inspect nested expressions.
        self.generic_visit(node)

    def visit_Call(self, node):
        # No Call-specific checks; keep traversing so comprehensions passed
        # as arguments (e.g. ``sum([...])``) are still inspected.
        self.generic_visit(node)

    def visit_ListComp(self, node):
        # Check for overly complex (deeply nested) comprehensions.
        nested_depth = self._get_comprehension_depth(node)
        if nested_depth > self.MAX_COMPREHENSION_DEPTH:
            self._add_issue(
                node.lineno,
                f"Nested comprehension depth {nested_depth} "
                f"(max {self.MAX_COMPREHENSION_DEPTH} recommended)",
                "warning"
            )
        self.generic_visit(node)

    # Set/dict comprehensions and generator expressions nest the same way.
    visit_SetComp = visit_ListComp
    visit_DictComp = visit_ListComp
    visit_GeneratorExp = visit_ListComp

    def _get_comprehension_depth(self, node, depth=1):
        """Return the deepest comprehension nesting level at or below *node*.

        *depth* is the level assigned to *node* itself. Nested comprehensions
        are found anywhere below it: in the element expression, in a
        ``for ... in`` iterable, in an ``if`` filter, or inside calls.
        """
        max_depth = depth
        for child in ast.iter_child_nodes(node):
            if isinstance(child, self._COMPREHENSION_TYPES):
                child_depth = depth + 1
            else:
                child_depth = depth
            max_depth = max(
                max_depth,
                self._get_comprehension_depth(child, child_depth)
            )
        return max_depth
def analyze_code(source: str) -> list:
    """Parse *source* and return the issues reported by CodeSmellDetector.

    Each issue is a dict with ``line``, ``message`` and ``severity`` keys.
    ``ast.parse`` raises SyntaxError if *source* is not valid Python.
    """
    smell_finder = CodeSmellDetector()
    smell_finder.visit(ast.parse(source))
    return smell_finder.issues
Practical Example 2: Finding Mutable Default Arguments
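This is one of the most common Python gotchas: default values are evaluated once, when the def statement runs, so a mutable default such as [] or {} is shared by every call that relies on it: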
class MutableDefaultDetector(ast.NodeVisitor):
    """Detect mutable default arguments in function definitions.

    A default is flagged when it is a list/dict/set literal or
    comprehension, or a call to a well-known mutable constructor such as
    ``list()``, ``dict()`` or ``collections.defaultdict(...)``. Other calls
    (``float('inf')``, ``frozenset()``, ``tuple()``) are not flagged.

    Each finding is appended to ``self.violations`` as a dict with
    ``function`` (the function name), ``line`` (the ``def`` line) and
    ``type`` (the AST node class name of the default, e.g. ``'List'``).
    """

    MUTABLE_TYPES = (ast.List, ast.Dict, ast.Set,
                     ast.ListComp, ast.DictComp, ast.SetComp, ast.Call)
    # Constructors whose result is a fresh mutable container. Matched by
    # bare name (``list()``) or by attribute name (``collections.deque()``).
    MUTABLE_CALLS = frozenset({
        'list', 'dict', 'set', 'bytearray',
        'deque', 'defaultdict', 'OrderedDict', 'Counter', 'ChainMap',
    })

    def __init__(self):
        self.violations = []

    def _is_mutable(self, default):
        """Return True if the default-value expression builds a mutable object."""
        if not isinstance(default, self.MUTABLE_TYPES):
            return False
        if isinstance(default, ast.Call):
            func = default.func
            if isinstance(func, ast.Name):
                name = func.id
            elif isinstance(func, ast.Attribute):
                name = func.attr
            else:
                return False
            return name in self.MUTABLE_CALLS
        return True

    def visit_FunctionDef(self, node):
        # kw_defaults holds None for keyword-only params without a default.
        for default in node.args.defaults + node.args.kw_defaults:
            if default is not None and self._is_mutable(default):
                self.violations.append({
                    'function': node.name,
                    'line': node.lineno,
                    'type': type(default).__name__
                })
        self.generic_visit(node)

    # ``async def`` defaults are evaluated and shared the same way.
    visit_AsyncFunctionDef = visit_FunctionDef
# Test it. The sample function bodies must be indented inside the string;
# otherwise ast.parse() raises IndentationError.
code = """
def add_item(item, items=[]):
    items.append(item)
    return items
def create_cache(cache={}):
    return cache
"""
tree = ast.parse(code)
detector = MutableDefaultDetector()
detector.visit(tree)
for v in detector.violations:
    print(f"Line {v['line']}: '{v['function']}' has mutable "
          f"default argument (type: {v['type']})")
Practical Example 3: Code Complexity Metrics
class ComplexityAnalyzer(ast.NodeVisitor):
    """Compute a cyclomatic-complexity score for every function.

    Results are stored in ``self.function_complexity`` keyed by function
    name (a later function with the same name overwrites an earlier one).
    Each value is a dict with ``complexity``, ``line`` and ``end_line``.
    A warning is printed for any function scoring above 10.
    """

    # Nodes that each introduce one extra branch.
    _DECISION_NODES = (ast.If, ast.IfExp, ast.While, ast.For, ast.AsyncFor,
                       ast.ExceptHandler)
    # Nested definitions are measured on their own (generic_visit reaches
    # them), so their decision points must not inflate the enclosing score.
    _NESTED_SCOPES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)

    def __init__(self):
        self.function_complexity = {}

    def _calculate_complexity(self, node):
        """Calculate cyclomatic complexity of a function.

        Starts at 1 and adds one per if/elif, conditional expression, loop
        (including ``async for``), except clause, comprehension ``for`` and
        comprehension ``if``. Each ``and``/``or`` chain adds
        ``len(values) - 1``. Bodies of nested function/class definitions
        are not included.
        """
        complexity = 1  # Base complexity
        pending = list(ast.iter_child_nodes(node))
        while pending:
            child = pending.pop()
            if isinstance(child, self._NESTED_SCOPES):
                continue
            if isinstance(child, self._DECISION_NODES):
                complexity += 1
            elif isinstance(child, ast.BoolOp):
                # ``a and b and c`` is one BoolOp with three values.
                complexity += len(child.values) - 1
            elif isinstance(child, ast.comprehension):
                # One for the ``for`` clause plus one per ``if`` filter.
                complexity += 1 + len(child.ifs)
            pending.extend(ast.iter_child_nodes(child))
        return complexity

    def visit_FunctionDef(self, node):
        complexity = self._calculate_complexity(node)
        self.function_complexity[node.name] = {
            'complexity': complexity,
            'line': node.lineno,
            'end_line': node.end_lineno or node.lineno,
        }
        if complexity > 10:
            print(f"WARNING: '{node.name}' has high complexity "
                  f"({complexity}). Consider refactoring.")
        self.generic_visit(node)

    # ``async def`` functions are measured the same way.
    visit_AsyncFunctionDef = visit_FunctionDef
For the complete guide with all code examples and advanced patterns, read the full article on our blog.
Originally published at WD Tech Blog. Follow for more Python tutorials, AI tools, and developer resources.
Top comments (0)