German Yamil

Python's ast Module: Parse, Inspect, and Validate Code at the Token Level

ast.parse() is one of the most underused tools in Python. Most developers use it only for "does this code have a syntax error?" and stop there.

The AST gives you full access to the structure of any Python program (every function, every import, every call, every assignment) before running a single line.

Here's what you can do with it.


๐ŸŽ Free: AI Publishing Checklist โ€” 7 steps in Python ยท Full pipeline: germy5.gumroad.com/l/xhxkzz (pay what you want, min $9.99)


Basics: parse and dump

import ast

source = """
def greet(name: str) -> str:
    return f"Hello, {name}!"
"""

tree = ast.parse(source)
print(ast.dump(tree, indent=2))

Output (truncated):

Module(
  body=[
    FunctionDef(
      name='greet',
      args=arguments(
        args=[arg(arg='name', annotation=Name(id='str'))],
        ...
      ),
      body=[Return(value=JoinedStr(...))],
      returns=Name(id='str')
    )
  ]
)

Every Python construct has a corresponding AST node type. The full list is in the ast documentation: about 100 node types covering every syntactic element of the language.
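To get a feel for how many node types even a tiny program touches, one option is to tally the class names that ast.walk() yields. This is a minimal sketch; the helper name count_node_types is made up for illustration.

```python
import ast
from collections import Counter

def count_node_types(code: str) -> Counter:
    """Tally how often each AST node class appears in the source."""
    tree = ast.parse(code)
    return Counter(type(node).__name__ for node in ast.walk(tree))

counts = count_node_types("x = 1 + 2\nprint(x)")
print(counts["Assign"], counts["BinOp"], counts["Call"])
# 1 1 1
```

The counter also includes context nodes such as Load and Store, which is a quick reminder that the tree records how a name is used, not just that it appears.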

Syntax validation (the fast gate)

import ast

def validate_syntax(code: str) -> tuple[bool, str]:
    """
    Check if code is syntactically valid Python.
    Returns (is_valid, error_message).
    Fast: no execution, no imports.
    """
    try:
        ast.parse(code)
        return True, ""
    except SyntaxError as e:
        return False, f"Line {e.lineno}: {e.msg}"

ok, msg = validate_syntax("def f(\n    # missing close\n")
print(ok, msg)  # False, plus location and message (on Python 3.10+: "Line 1: '(' was never closed")

ok, msg = validate_syntax("x = 1 + 2")
print(ok, msg)  # True, with an empty message

This is what I use as Gate 1 in the ebook pipeline. Parsing a short snippet is cheap, typically well under a millisecond, and it catches every syntax error before we ever attempt execution.
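SyntaxError carries more than a line number. As a hedged sketch (the function name describe_syntax_error is hypothetical), the column offset and the offending source line can be folded into the report so that a failed gate is easier to debug:

```python
import ast
from typing import Optional

def describe_syntax_error(code: str) -> Optional[str]:
    """Return a one-line report for the first syntax error, or None if the code parses."""
    try:
        ast.parse(code)
    except SyntaxError as e:
        # e.text is the offending source line (may be None); e.offset is a 1-based column
        snippet = (e.text or "").strip()
        return f"line {e.lineno}, col {e.offset}: {e.msg} | {snippet}"
    return None

print(describe_syntax_error("x = = 1"))
print(describe_syntax_error("x = 1"))  # None
```

The exact message text varies between Python versions, so it is safer to log it than to match on it.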

Walking the tree: ast.walk()

import ast

def extract_imports(code: str) -> list[str]:
    """Extract all imported module names from source code."""
    tree = ast.parse(code)
    imports = []
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                imports.append(alias.name)
        elif isinstance(node, ast.ImportFrom):
            if node.module:
                imports.append(node.module)
    return imports

code = """
import os
import sys
from pathlib import Path
from collections import defaultdict
"""

print(extract_imports(code))
# ['os', 'sys', 'pathlib', 'collections']
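ast.walk() yields nodes in no specified order and gives you no per-node-type hooks. When order or structure matters, ast.NodeVisitor is the standard alternative: it dispatches to visit_<NodeType> methods as it descends. The sketch below is one possible variant that also records the imported names; the class name ImportCollector is an assumption for illustration.

```python
import ast

class ImportCollector(ast.NodeVisitor):
    """Collect qualified imported names, e.g. 'pathlib.Path'."""

    def __init__(self) -> None:
        self.names: list[str] = []

    def visit_Import(self, node: ast.Import) -> None:
        for alias in node.names:
            self.names.append(alias.name)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        # node.level counts the leading dots of a relative import; node.module may be None
        prefix = "." * node.level + (node.module or "")
        for alias in node.names:
            if node.module:
                self.names.append(f"{prefix}.{alias.name}")
            else:
                self.names.append(f"{prefix}{alias.name}")

collector = ImportCollector()
collector.visit(ast.parse("import os\nfrom pathlib import Path\nfrom . import utils"))
print(collector.names)
# ['os', 'pathlib.Path', '.utils']
```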

Detect third-party imports

This is how I prevent generated code from importing packages that don't exist in the clean subprocess environment:

import ast, sys

STDLIB_MODULES = set(sys.stdlib_module_names)  # Python 3.10+

def find_third_party_imports(code: str) -> list[str]:
    """
    Return list of imported modules that are not in stdlib.
    """
    tree = ast.parse(code)
    third_party = []

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                root = alias.name.split('.')[0]
                if root not in STDLIB_MODULES:
                    third_party.append(alias.name)
        elif isinstance(node, ast.ImportFrom):
            if node.module and node.level == 0:  # skip relative imports such as "from .utils import x"
                root = node.module.split('.')[0]
                if root not in STDLIB_MODULES:
                    third_party.append(node.module)

    return third_party

code = """
import os
import pandas as pd
from pathlib import Path
import numpy as np
"""

print(find_third_party_imports(code))
# ['pandas', 'numpy']
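A stdlib allowlist answers "is this third-party?", which is not quite the same as "will this import succeed?". If the question is availability, one option is importlib.util.find_spec(), which locates a top-level module without importing it. This sketch assumes the check runs under the same interpreter and environment as the later execution step; the function name find_unavailable_imports is hypothetical.

```python
import ast
import importlib.util

def find_unavailable_imports(code: str) -> list[str]:
    """Return top-level modules that the current interpreter cannot locate."""
    missing = []
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Import):
            roots = [alias.name.split(".")[0] for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            roots = [node.module.split(".")[0]]
        else:
            continue
        for root in roots:
            try:
                # find_spec() on a top-level name locates the module without importing it
                spec = importlib.util.find_spec(root)
            except ValueError:
                continue  # already loaded but has no spec (e.g. __main__): treat as available
            if spec is None and root not in missing:
                missing.append(root)
    return missing

print(find_unavailable_imports("import os\nimport surely_not_installed_pkg_xyz"))
# ['surely_not_installed_pkg_xyz']
```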

Detect dangerous patterns

Before running untrusted code, check for patterns that shouldn't be there:

import ast

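DANGEROUS_CALLS = {
    'eval', 'exec', '__import__', 'compile',
    'open',       # file system access
}
# Shell execution (os.system, subprocess.run, ...) is matched by attribute name below;
# 'subprocess' itself is a module, so it never appears as a direct call.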

def has_dangerous_calls(code: str) -> list[str]:
    """Return list of dangerous function calls found in code."""
    tree = ast.parse(code)
    found = []

    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            # Direct call: eval(...)
            if isinstance(node.func, ast.Name):
                if node.func.id in DANGEROUS_CALLS:
                    found.append(node.func.id)
            # Attribute call: os.system(...), subprocess.run(...)
            elif isinstance(node.func, ast.Attribute):
                if node.func.attr in {'system', 'popen', 'run', 'call'}:
                    parent = node.func.value
                    if isinstance(parent, ast.Name):
                        found.append(f"{parent.id}.{node.func.attr}")

    return found

code = """
import os
result = os.system("ls -la")
data = eval(user_input)
"""

print(has_dangerous_calls(code))
# ['os.system', 'eval']
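The check above matches on the names as written, so `import subprocess as sp` followed by `sp.run(...)` is reported as 'sp.run', and `from os import system` slips through entirely. A possible refinement is to build an alias map from the import statements first. This is a sketch, not a sandbox: the FLAGGED set and the function name are assumptions for illustration, and dynamic tricks such as getattr() will still get past any static check.

```python
import ast

# Assumption for illustration: fully qualified names we want to flag
FLAGGED = {"os.system", "os.popen", "subprocess.run", "subprocess.call", "subprocess.Popen"}

def find_aliased_dangerous_calls(code: str) -> list[str]:
    """Resolve import aliases, then report flagged calls by their real names."""
    tree = ast.parse(code)
    aliases: dict[str, str] = {}  # local name -> fully qualified name

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for a in node.names:
                # "import a.b" binds "a"; "import a.b as c" binds "c" to "a.b"
                local = a.asname or a.name.split(".")[0]
                aliases[local] = a.name if a.asname else local
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            for a in node.names:
                aliases[a.asname or a.name] = f"{node.module}.{a.name}"

    found = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        if isinstance(func, ast.Name):
            qualified = aliases.get(func.id, func.id)
        elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
            qualified = f"{aliases.get(func.value.id, func.value.id)}.{func.attr}"
        else:
            continue
        if qualified in FLAGGED:
            found.append(qualified)
    return found

code = """
import subprocess as sp
from os import system as shell
sp.run(["ls"])
shell("ls -la")
"""
print(find_aliased_dangerous_calls(code))
# ['subprocess.run', 'os.system']
```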

Extract all function names and signatures

import ast

def extract_functions(code: str) -> list[dict]:
    """Extract function names, arguments, and return type annotations."""
    tree = ast.parse(code)
    functions = []

    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            args = [arg.arg for arg in node.args.args]
            returns = None
            if node.returns:
                returns = ast.unparse(node.returns)

            # ast.get_docstring() returns None unless the first statement is a string literal
            docstring = ast.get_docstring(node)

            functions.append({
                "name": node.name,
                "args": args,
                "returns": returns,
                "has_docstring": docstring is not None,
                "line": node.lineno,
            })

    return functions

code = """
def add(a: int, b: int) -> int:
    \"\"\"Add two integers.\"\"\"
    return a + b

def greet(name):
    print(f"Hello, {name}")
"""

for fn in extract_functions(code):
    print(fn)
# {'name': 'add', 'args': ['a', 'b'], 'returns': 'int', 'has_docstring': True, 'line': 2}
# {'name': 'greet', 'args': ['name'], 'returns': None, 'has_docstring': False, 'line': 6}
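Two gaps are worth knowing about: ast.FunctionDef does not match `async def` (that is ast.AsyncFunctionDef), and node.args.args only lists ordinary positional parameters, so *args, keyword-only parameters, and **kwargs are skipped. One way to cover both, sketched here with a hypothetical extract_signatures helper, is to check both node types and let ast.unparse() (Python 3.9+) render the whole arguments node:

```python
import ast

def extract_signatures(code: str) -> list[dict]:
    """List sync and async functions with their parameter text."""
    results = []
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            results.append({
                "name": node.name,
                "is_async": isinstance(node, ast.AsyncFunctionDef),
                # ast.unparse() renders the arguments node back to source text
                "params": ast.unparse(node.args),
                "has_docstring": ast.get_docstring(node) is not None,
            })
    return results

code = '''
async def fetch(url, *, timeout=10, **options):
    """Fetch a URL."""

def helper(*values):
    pass
'''
for sig in extract_signatures(code):
    print(sig)
```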

Enforce coding standards before execution

import ast

def enforce_standards(code: str) -> list[str]:
    """
    Check code against mandatory standards.
    Returns list of violations.
    """
    violations = []
    tree = ast.parse(code)

    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            # Every function must have a docstring
            has_doc = (
                node.body
                and isinstance(node.body[0], ast.Expr)
                and isinstance(node.body[0].value, ast.Constant)
                and isinstance(node.body[0].value.value, str)
            )
            if not has_doc:
                violations.append(
                    f"Line {node.lineno}: function '{node.name}' has no docstring"
                )

            # Every function must have return type annotation
            if node.returns is None and node.name != '__init__':
                violations.append(
                    f"Line {node.lineno}: function '{node.name}' missing return type annotation"
                )

    return violations

code = """
def add(a: int, b: int) -> int:
    \"\"\"Add two integers.\"\"\"
    return a + b

def no_docstring(x):
    return x * 2
"""

violations = enforce_standards(code)
for v in violations:
    print(v)
# Line 6: function 'no_docstring' has no docstring
# Line 6: function 'no_docstring' missing return type annotation
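The same pattern extends to other rules. As one more illustration (this rule is my own example, not part of the pipeline described here), a check for mutable default arguments only needs to look at the defaults stored on the function's arguments node:

```python
import ast

def find_mutable_defaults(code: str) -> list[str]:
    """Flag parameters whose default is a list, dict, or set literal."""
    violations = []
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # kw_defaults holds None for keyword-only parameters without a default
            defaults = node.args.defaults + [d for d in node.args.kw_defaults if d is not None]
            for default in defaults:
                if isinstance(default, (ast.List, ast.Dict, ast.Set)):
                    violations.append(
                        f"Line {node.lineno}: function '{node.name}' has a mutable default argument"
                    )
    return violations

print(find_mutable_defaults("def collect(item, bucket=[]):\n    bucket.append(item)\n    return bucket\n"))
# ["Line 1: function 'collect' has a mutable default argument"]
```

It only catches literals; a default written as list() or dict() would need an extra ast.Call check.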

The full pre-flight check

import ast

def preflight(code: str, allow_third_party: bool = False) -> dict:
    """
    Run all AST checks before executing code.
    Returns dict with 'passed' bool and 'issues' list.
    """
    issues = []

    # Gate 1: syntax
    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        return {"passed": False, "issues": [f"SyntaxError line {e.lineno}: {e.msg}"]}

    # Gate 2: dangerous calls
    dangerous = has_dangerous_calls(code)
    if dangerous:
        issues.append(f"Dangerous calls detected: {dangerous}")

    # Gate 3: third-party imports (optional)
    if not allow_third_party:
        third_party = find_third_party_imports(code)
        if third_party:
            issues.append(f"Third-party imports not allowed: {third_party}")

    # Gate 4: coding standards
    violations = enforce_standards(code)
    issues.extend(violations)

    return {
        "passed": len(issues) == 0,
        "issues": issues,
    }

# Usage
result = preflight("""
import os
def process(data: list) -> list:
    \"\"\"Process a list of items.\"\"\"
    return [x * 2 for x in data]
""")
print(result)
# {'passed': True, 'issues': []}
# (importing os is allowed; only calls such as os.system(...) are flagged)

The two-gate system using preflight() + subprocess.run() is the foundation of the ebook validation pipeline: germy5.gumroad.com/l/xhxkzz (pay what you want, min $9.99).
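For readers who want to see the shape of that two-gate idea, here is a minimal sketch under stated assumptions. The function run_gated and its return format are hypothetical, the static gate is reduced to a syntax check so the example stays self-contained, and a subprocess is process isolation only, not a security sandbox.

```python
import ast
import subprocess
import sys

def run_gated(code: str, timeout: float = 5.0) -> dict:
    """Gate 1: static check. Gate 2: execute in a separate interpreter process."""
    # Stand-in for preflight(): only the syntax gate, to keep the sketch self-contained
    try:
        ast.parse(code)
    except SyntaxError as e:
        return {"passed": False, "stage": "preflight", "detail": f"line {e.lineno}: {e.msg}"}

    try:
        # -I runs Python in isolated mode (ignores PYTHON* env vars and user site-packages)
        proc = subprocess.run(
            [sys.executable, "-I", "-c", code],
            capture_output=True, text=True, timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return {"passed": False, "stage": "execution", "detail": f"timed out after {timeout}s"}

    if proc.returncode != 0:
        return {"passed": False, "stage": "execution", "detail": proc.stderr.strip()}
    return {"passed": True, "stage": "execution", "detail": proc.stdout}

print(run_gated("print(1 + 1)"))
print(run_gated("def broken(:"))
```

Code that fails the static gate never reaches the interpreter, and code that passes still runs with a timeout and in a separate process, so a crash or hang cannot take the caller down with it.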

