
Shoaibali Mir

Scaling AI Agents: Distributed Graph Traversal and Choosing the Right Graph Database

In Part 1, I showed why graph traversal speed is the bottleneck.

In Part 2, I covered hybrid LLM-graph planning for single-system agents.

But here's the next problem:

What happens when your infrastructure spans multiple regions, thousands of services, and millions of state transitions?

A single graph database instance struggles to deliver sub-second query times once the graph grows past 10M+ nodes.

Standard A* search becomes a distributed systems problem.

Let's dig into how to architect planet-scale autonomous agents that can:

  • Plan across distributed infrastructure graphs
  • Query massive graphs efficiently
  • Handle cross-region dependencies
  • Scale horizontally

The Scale Problem

When Single-Node Graphs Break

Your Kubernetes cluster at scale:

[Diagram: Kubernetes cluster at scale]

The challenge:

[Diagram: the scaling challenge]

What you need:

[Diagram: what you need]


Architecture: Distributed Graph Traversal

The Three-Layer Approach

[Diagram: the three-layer approach]


Part 1: Distributed Graph Traversal

Strategy 1: Graph Partitioning

Partition by service domain:

from collections import defaultdict

class GraphPartitioner:
    def __init__(self, partition_strategy='domain'):
        self.strategy = partition_strategy
        self.partitions = {}
        self.all_edges = []

    def partition_graph(self, services, dependencies):
        """
        Partition the infrastructure graph by domain boundaries
        """
        self.all_edges = dependencies

        if self.strategy == 'domain':
            return self._partition_by_domain(services)
        elif self.strategy == 'region':
            return self._partition_by_region(services)
        elif self.strategy == 'hybrid':
            return self._partition_hybrid(services)

    def _partition_by_domain(self, services):
        """
        Group services by business domain
        e.g., payment, auth, search, recommendation
        """
        domains = defaultdict(list)

        for service in services:
            domain = service.domain  # e.g., 'payment', 'auth'
            domains[domain].append(service)

        # Create a partition for each domain
        partitions = {}
        for domain, service_list in domains.items():
            partitions[domain] = {
                'services': service_list,
                'internal_edges': self._get_internal_edges(service_list),
                'external_edges': self._get_external_edges(service_list),
                'size': len(service_list)
            }

        return partitions

    def _get_internal_edges(self, services):
        """Edges within the same partition"""
        service_ids = {s.id for s in services}
        return [
            edge for edge in self.all_edges
            if edge.source in service_ids and edge.target in service_ids
        ]

    def _get_external_edges(self, services):
        """Edges crossing partition boundaries"""
        service_ids = {s.id for s in services}
        return [
            edge for edge in self.all_edges
            if (edge.source in service_ids) != (edge.target in service_ids)
        ]

Result:

[Diagram: partitioning result]
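
To make the partitioner's inputs concrete, here's a minimal usage sketch. Service and Edge are hypothetical stand-in records (the post doesn't define them), and the dependency edges are whatever your discovery tooling produces:

from collections import namedtuple

# Hypothetical stand-ins for the service catalog and dependency edges
Service = namedtuple('Service', ['id', 'domain'])
Edge = namedtuple('Edge', ['source', 'target'])

services = [
    Service('payment-api', 'payment'),
    Service('payment-db', 'payment'),
    Service('auth-api', 'auth'),
    Service('search-api', 'search'),
]

dependencies = [
    Edge('payment-api', 'payment-db'),   # stays inside the payment partition
    Edge('payment-api', 'auth-api'),     # crosses the payment/auth boundary
    Edge('search-api', 'auth-api'),      # crosses the search/auth boundary
]

partitioner = GraphPartitioner(partition_strategy='domain')
partitions = partitioner.partition_graph(services, dependencies)

for domain, partition in partitions.items():
    print(domain, partition['size'],
          len(partition['internal_edges']),
          len(partition['external_edges']))
# payment 2 1 1
# auth 1 0 2
# search 1 0 1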

Strategy 2: Distributed A* Search

Multi-partition pathfinding:

import networkx as nx

class DistributedPlanner:
    def __init__(self, partitions, partition_metadata):
        self.partitions = partitions
        self.metadata = partition_metadata
        self.local_planners = {}

        # Create a local planner for each partition
        # (LocalPlanner is assumed to wrap a single-partition graph store)
        for partition_id, graph_db_url in partitions.items():
            self.local_planners[partition_id] = LocalPlanner(graph_db_url)

    def distributed_search(self, start_state, goal_state):
        """
        Distributed A* search across graph partitions
        """
        # Step 1: Identify which partitions contain start and goal
        start_partition = self._find_partition(start_state)
        goal_partition = self._find_partition(goal_state)

        # Step 2: If same partition, use local search
        if start_partition == goal_partition:
            return self.local_planners[start_partition].search(
                start_state, goal_state
            )

        # Step 3: Find inter-partition path
        partition_path = self._find_partition_path(
            start_partition, goal_partition
        )

        # Step 4: Search within each partition along the path
        full_path = []
        current_state = start_state

        for i in range(len(partition_path) - 1):
            current_partition = partition_path[i]
            next_partition = partition_path[i + 1]

            # Find exit point from current partition
            exit_state = self._find_best_exit_state(
                current_state, current_partition, next_partition
            )

            # Local search within partition
            local_path = self.local_planners[current_partition].search(
                current_state, exit_state
            )

            full_path.extend(local_path)
            current_state = exit_state

        # Final hop to goal
        final_path = self.local_planners[goal_partition].search(
            current_state, goal_state
        )
        full_path.extend(final_path)

        return full_path

    def _find_partition_path(self, start_partition, goal_partition):
        """
        Find shortest path through partitions using partition graph
        """
        # Build partition-level graph
        partition_graph = nx.DiGraph()

        for partition_id, metadata in self.metadata.items():
            partition_graph.add_node(partition_id)

            # Add edges to connected partitions
            for neighbor_id in metadata['neighbors']:
                weight = len(metadata['external_edges'][neighbor_id])
                partition_graph.add_edge(
                    partition_id, neighbor_id, weight=weight
                )

        # Shortest path at partition level
        return nx.shortest_path(
            partition_graph, 
            start_partition, 
            goal_partition,
            weight='weight'
        )

    def _find_best_exit_state(self, current_state, from_partition, to_partition):
        """
        Find the best state to transition between partitions
        Based on cost and likelihood of success
        """
        exit_states = self.metadata[from_partition]['exit_states'][to_partition]

        # Score each potential exit state
        scores = []
        for exit_state in exit_states:
            # Distance from current to exit
            local_cost = self.local_planners[from_partition].estimate_cost(
                current_state, exit_state
            )

            # Historical success rate of this transition
            success_rate = self._get_transition_success_rate(
                from_partition, to_partition, exit_state
            )

            score = local_cost / (success_rate + 0.01)  # Lower is better
            scores.append((exit_state, score))

        # Return exit state with best score
        return min(scores, key=lambda x: x[1])[0]
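
One thing the planner above glosses over is what partition_metadata actually contains. Here's a rough sketch of the shape it assumes; the keys ('neighbors', 'external_edges', 'exit_states') are read directly by the code above, while the example values are purely illustrative:

# Hypothetical shape of partition_metadata, inferred from how
# DistributedPlanner reads it; adapt the details to your own setup
partition_metadata = {
    'payment': {
        # Partitions reachable via at least one cross-partition edge
        'neighbors': ['auth'],

        # Cross-partition edges keyed by neighboring partition;
        # their count becomes the weight in the partition-level graph
        'external_edges': {
            'auth': [('payment-api', 'auth-api')],
        },

        # Candidate states where a plan can hand off to each neighbor
        'exit_states': {
            'auth': ['payment-api:healthy'],
        },
    },
    # ... one entry per partition
}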

Strategy 3: Parallel Query Execution

Execute searches in parallel across partitions:

import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

class ParallelGraphQuerier:
    def __init__(self, partition_planners):
        self.planners = partition_planners
        self.executor = ThreadPoolExecutor(max_workers=8)

    async def parallel_search(self, query_plan):
        """
        Execute multiple graph queries in parallel
        """
        tasks = []

        for partition_id, local_query in query_plan.items():
            task = asyncio.create_task(
                self._execute_local_query(partition_id, local_query)
            )
            tasks.append(task)

        # Wait for all queries to complete
        results = await asyncio.gather(*tasks)

        # Merge results
        return self._merge_results(results)

    async def _execute_local_query(self, partition_id, query):
        """
        Execute query against single partition
        """
        loop = asyncio.get_event_loop()

        # Run blocking graph query in thread pool
        result = await loop.run_in_executor(
            self.executor,
            self.planners[partition_id].execute_query,
            query
        )

        return {
            'partition': partition_id,
            'result': result,
            'timestamp': time.time()
        }

    def _merge_results(self, results):
        """
        Combine results from multiple partitions
        """
        merged_paths = []

        # Order results by partition id (assumes ids reflect their position in the overall path)
        sorted_results = sorted(results, key=lambda x: x['partition'])

        for result in sorted_results:
            merged_paths.extend(result['result']['path'])

        return merged_paths
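
For context, here's an illustrative way to call it. The planner objects, the query dict format, and the expectation that execute_query() returns a dict with a 'path' key are assumptions, since the post doesn't pin them down:

# Illustrative usage: fan queries out to two partitions in parallel.
# Each planner's execute_query() is expected to return {'path': [...]}
querier = ParallelGraphQuerier({
    'payment': payment_planner,   # LocalPlanner-style objects (assumed)
    'auth': auth_planner,
})

query_plan = {
    'payment': {'start': 'payment-api:degraded', 'goal': 'payment-api:healthy'},
    'auth':    {'start': 'auth-api:degraded',    'goal': 'auth-api:healthy'},
}

merged_path = asyncio.run(querier.parallel_search(query_plan))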

Part 2: Choosing the Right Graph Database

The Landscape

Commercial Options:

  • Neo4j - Industry standard, rich ecosystem
  • Amazon Neptune - Managed AWS service
  • TigerGraph - Optimized for analytics
  • Azure Cosmos DB (Gremlin API) - Azure-native

Custom/Open Source:

  • JanusGraph - Distributed, open source
  • ArangoDB - Multi-model database
  • Redis + Custom Logic - Maximum control

Decision Framework

(--> marks signals it's a good fit; --X marks caveats to weigh. A toy code version of these rules follows the lists.)

When to use Neo4j:
--> Need rich query language (Cypher)
--> Complex pattern matching required
--> Excellent tooling/ecosystem matters
--> Team familiar with property graphs
--X Budget constraints (<100K nodes: Community Edition works)
--X Cloud vendor lock-in concerns

When to use Neptune:
--> Already on AWS
--> Want fully managed service
--> Need automatic scaling
--> Integration with AWS services critical
--X Team knows Gremlin or willing to learn
--X Cost of managed service acceptable

When to use TigerGraph:
--> Deep graph analytics required
--> Real-time recommendations
--> Need distributed traversal out-of-the-box
--> Handling massive graphs (billions of edges)
--X Smaller community acceptable
--X Learning GSQL is worthwhile

When to build custom (Redis + adjacency lists):
--> Simple shortest path queries dominate
--> Extreme cost sensitivity
--> Need maximum performance for specific patterns
--> Strong engineering team
--X Limited feature requirements
--X Can invest engineering time
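
If you want the framework as something executable, here's a toy encoding of the same rules. The thresholds mirror the criteria in the next section; treat the output as a starting point, not a verdict:

def choose_graph_db(nodes: int, needs_complex_patterns: bool,
                    heavy_analytics: bool, on_aws: bool,
                    budget_tight: bool) -> str:
    """Toy encoding of the decision framework above; tune the thresholds to taste."""
    if heavy_analytics or nodes > 10_000_000:
        return 'TigerGraph (or custom distributed)'
    if budget_tight and not needs_complex_patterns:
        return 'Custom Redis + adjacency lists'
    if nodes < 100_000:
        return 'Neo4j Community'
    if on_aws and not budget_tight:
        return 'Amazon Neptune'
    return 'Hybrid: Redis cache + Neo4j'

print(choose_graph_db(nodes=2_000_000, needs_complex_patterns=True,
                      heavy_analytics=False, on_aws=False, budget_tight=False))
# Hybrid: Redis cache + Neo4j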

Key Evaluation Criteria

1. Query Patterns

Shortest path only --> Custom/Redis
Complex patterns --> Neo4j/Neptune
Analytics --> TigerGraph
Mixed --> Neo4j or hybrid approach

2. Scale

<100K nodes --> Neo4j Community (free)
100K-1M nodes --> Any commercial option
1M-10M nodes --> TigerGraph or distributed Neo4j
10M+ nodes --> TigerGraph or custom distributed

3. Budget

Tight budget --> Custom or Neo4j Community
Moderate --> Neptune or Neo4j Enterprise
Analytics-focused --> TigerGraph

4. Team Expertise

SQL background --> Cypher (Neo4j)
NoSQL background --> Gremlin (Neptune)
Python heavy --> Custom implementation

Database Implementation Patterns

Neo4j Pattern

from neo4j import GraphDatabase

class Neo4jPlanner:
    def __init__(self, uri, user, password):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def shortest_path(self, start_id, end_id):
        with self.driver.session() as session:
            result = session.run("""
                MATCH (start:Service {id: $start_id}),
                      (end:Service {id: $end_id}),
                      path = shortestPath((start)-[:DEPENDS_ON*]-(end))
                RETURN path, length(path) as cost
            """, start_id=start_id, end_id=end_id)

            record = result.single()
            return {
                'path': record['path'],
                'cost': record['cost']
            }

    def get_neighbors(self, node_id, depth=3):
        # Cypher can't parameterize variable-length path bounds, so the depth
        # is interpolated into the query string (keep it a trusted integer)
        query = f"""
            MATCH (start:Service {{id: $node_id}})-[:DEPENDS_ON*1..{int(depth)}]-(neighbor)
            RETURN DISTINCT neighbor.id AS id, neighbor.name AS name
        """
        with self.driver.session() as session:
            result = session.run(query, node_id=node_id)
            return [dict(record) for record in result]

    def pattern_match(self, pattern_query):
        """
        Execute custom Cypher pattern
        Example: Find all services depending on auth
        """
        with self.driver.session() as session:
            result = session.run(pattern_query)
            return [dict(record) for record in result]

Pros:

  • Mature, battle-tested
  • Excellent query language (Cypher is SQL-like)
  • Rich visualization tools
  • Strong consistency guarantees

Cons:

  • Memory intensive (plan for 10x node count in RAM)
  • Licensing costs for clustering
  • Vertical scaling limits

Neptune Pattern

from gremlin_python.driver import client, serializer

class NeptunePlanner:
    def __init__(self, endpoint):
        self.client = client.Client(
            f'wss://{endpoint}:8182/gremlin',
            'g',
            message_serializer=serializer.GraphSONSerializersV2d0()
        )

    def shortest_path(self, start_id, end_id):
        query = f"""
        g.V().has('service', 'id', '{start_id}')
         .repeat(out('depends_on').simplePath())
         .until(has('id', '{end_id}'))
         .path()
         .limit(1)
        """

        result = self.client.submit(query).all().result()
        return result[0] if result else None

    def get_neighbors(self, node_id, depth=3):
        query = f"""
        g.V().has('service', 'id', '{node_id}')
         .repeat(out('depends_on'))
         .times({depth})
         .dedup()
         .valueMap()
        """

        return self.client.submit(query).all().result()

Pros:

  • Fully managed (no ops burden)
  • Auto-scaling
  • AWS integration (IAM, VPC, CloudWatch)
  • Multiple APIs (Gremlin, SPARQL)

Cons:

  • Gremlin learning curve
  • Potentially higher cost than self-hosted
  • Less query optimization control

Custom Redis Pattern

import json
from collections import defaultdict

import redis

class RedisGraphPlanner:
    def __init__(self, redis_host='localhost'):
        self.redis = redis.Redis(host=redis_host, decode_responses=True)
        self.adjacency_list = defaultdict(list)

    def load_graph(self, nodes, edges):
        """Load graph into Redis"""
        for edge in edges:
            # Store adjacency list
            key = f"neighbors:{edge['source']}"
            neighbor_data = json.dumps({
                'target': edge['target'],
                'cost': edge.get('cost', 1)
            })

            self.redis.rpush(key, neighbor_data)
            self.redis.expire(key, 3600)  # 1 hour TTL

    def shortest_path(self, start_id, end_id):
        """
        Dijkstra's algorithm using Redis for storage
        """
        # Check cache first
        cache_key = f"path:{start_id}:{end_id}"
        cached = self.redis.get(cache_key)

        if cached:
            return json.loads(cached)

        # Compute path
        distances = {start_id: 0}
        previous = {}
        unvisited = {start_id}

        while unvisited:
            current = min(unvisited, key=lambda x: distances.get(x, float('inf')))

            if current == end_id:
                break

            unvisited.remove(current)

            # Get neighbors from Redis
            neighbors_raw = self.redis.lrange(f"neighbors:{current}", 0, -1)
            neighbors = [json.loads(n) for n in neighbors_raw]

            for neighbor in neighbors:
                distance = distances[current] + neighbor['cost']
                target = neighbor['target']

                if distance < distances.get(target, float('inf')):
                    distances[target] = distance
                    previous[target] = current
                    unvisited.add(target)

        # Goal unreachable from the loaded edges
        if end_id != start_id and end_id not in previous:
            return None

        # Reconstruct path
        path = self._reconstruct_path(previous, start_id, end_id)

        # Cache result
        self.redis.set(cache_key, json.dumps(path), ex=600)

        return path

    def _reconstruct_path(self, previous, start, end):
        path = []
        current = end

        while current in previous:
            path.append(current)
            current = previous[current]

        path.append(start)
        path.reverse()

        return path
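
A quick usage sketch against a local Redis instance (the edge dicts match what load_graph expects; the node list isn't used by this minimal version):

planner = RedisGraphPlanner(redis_host='localhost')

edges = [
    {'source': 'payment-api', 'target': 'payment-db', 'cost': 1},
    {'source': 'payment-api', 'target': 'auth-api', 'cost': 2},
    {'source': 'auth-api', 'target': 'auth-db', 'cost': 1},
]

planner.load_graph(nodes=[], edges=edges)

# First call runs Dijkstra over the Redis-backed adjacency lists and caches the result;
# repeat calls within the 10-minute TTL are served straight from the cache
print(planner.shortest_path('payment-api', 'auth-db'))
# ['payment-api', 'auth-api', 'auth-db']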

Pros:

  • Extreme performance for simple queries
  • Full control over caching strategy
  • Very cost-effective
  • Scales horizontally easily

Cons:

  • No complex pattern matching
  • More engineering effort
  • Limited built-in features
  • Need to build tooling

Hybrid Architecture: Best of Both Worlds

The Recommended Production Setup

class HybridGraphStore:
    """
    Combine Redis for hot paths + Neo4j for complex queries
    """
    def __init__(self, redis_cluster, neo4j_cluster):
        self.cache = RedisGraphPlanner(redis_cluster)
        self.graph_db = Neo4jPlanner(neo4j_cluster)
        self.hot_path_threshold = 100  # queries/hour

    def shortest_path(self, start, end):
        # Track query frequency
        query_key = f"path:{start}:{end}"
        query_count = int(self.cache.redis.get(f"count:{query_key}") or 0)
        self.cache.redis.incr(f"count:{query_key}")
        self.cache.redis.expire(f"count:{query_key}", 3600)

        # Use Redis for hot paths (frequently queried)
        if query_count > self.hot_path_threshold:
            cached_path = self.cache.shortest_path(start, end)
            if cached_path:
                return cached_path

        # Fall back to Neo4j for rare/complex queries
        path = self.graph_db.shortest_path(start, end)

        # Warm the cache for next time
        cache_key = f"path:{start}:{end}"
        self.cache.redis.set(cache_key, json.dumps(path), ex=3600)

        return path

    def pattern_match(self, pattern):
        """Complex patterns always go to Neo4j"""
        return self.graph_db.pattern_match(pattern)

    def update_state(self, node_id, new_state):
        """Write-through: update both stores"""
        # Invalidate cached paths that start or end at this node
        for pattern in (f"path:{node_id}:*", f"path:*:{node_id}"):
            for key in self.cache.redis.scan_iter(match=pattern):
                self.cache.redis.delete(key)

        # Update Neo4j (update_node_property is assumed to run a small
        # Cypher SET query; it isn't shown in the Neo4jPlanner above)
        self.graph_db.update_node_property(node_id, 'state', new_state)

Why this works (rough numbers below):

  • 80% of queries hit Redis cache (hot paths)
  • 20% use Neo4j (complex/rare queries)
  • Cost: ~60% less than Neo4j-only
  • Performance: Better P95 latency
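
Those bullets follow directly from the cache hit rate. A rough back-of-the-envelope model (every number below is an illustrative assumption, not a benchmark):

# Blended latency / cost model for the hybrid setup; all inputs are assumptions
hit_rate = 0.80                               # share of queries served by Redis
redis_ms, neo4j_ms = 15, 250                  # assumed per-query latencies
redis_cost, neo4j_cost = 0.0000025, 0.00001   # assumed $ per query

avg_latency = hit_rate * redis_ms + (1 - hit_rate) * neo4j_ms
blended_cost = hit_rate * redis_cost + (1 - hit_rate) * neo4j_cost

print(f"average latency ≈ {avg_latency:.0f} ms")                 # ≈ 62 ms
print(f"cost vs. Neo4j-only ≈ {blended_cost / neo4j_cost:.0%}")  # 40%, i.e. ~60% cheaper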

Performance Expectations

What to Expect from Each Approach

Neo4j (Single Instance):

Graph size: 1M nodes
Shortest path: 50-200ms typical
Complex patterns: 200-500ms typical
Memory: ~10GB for 1M nodes

Neptune:

Graph size: 1M nodes
Shortest path: 100-300ms typical
Complex patterns: 300-800ms typical
Managed service overhead: +20-40% latency vs. self-hosted

Custom Redis:

Graph size: 1M nodes
Shortest path (cached): 5-15ms
Shortest path (uncached): 20-80ms
Cache hit rate: 70-90% typical for production workloads

Hybrid (Redis + Neo4j):

Overall P95: 80-150ms (depending on cache hit rate)
Hot paths (80% of queries): 10-20ms
Cold paths (20% of queries): 150-300ms
Best balance of cost and performance

Cost Considerations

Rough Monthly Cost Estimates

Small Scale (100K nodes, 1K queries/sec):

Neo4j Community (self-hosted): $200-400 (EC2 costs)
Neptune (db.r6g.large): $350-500
Custom Redis: $100-200

Medium Scale (1M nodes, 10K queries/sec):

Neo4j Enterprise (clustered): $1,500-2,500
Neptune (db.r6g.2xlarge): $1,400-2,000
TigerGraph Cloud: $1,200-1,800
Custom Redis Cluster: $400-800
Hybrid (Redis + Neo4j): $800-1,400

Large Scale (10M+ nodes, 50K queries/sec):

Neo4j Enterprise (large cluster): $4,000-8,000
Neptune (db.r6g.8xlarge): $5,000-7,000
TigerGraph: $3,500-6,000
Custom Distributed: $2,000-4,000

Implementation Roadmap

Phase 1: Start Small (Week 1-2)

# Deploy Neo4j Community
docker run -d \
  -p 7474:7474 -p 7687:7687 \
  -v $PWD/data:/data \
  neo4j:latest

# Load your infrastructure graph
python load_graph.py --source infrastructure.json --target neo4j://localhost

# Measure baseline performance
python measure_query_performance.py --queries 1000
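
The measure_query_performance.py script isn't shown in the post; here's a minimal sketch of what it could look like, reusing the Neo4jPlanner class from earlier and assuming a simple node-id scheme:

# Hypothetical sketch of measure_query_performance.py: time N shortest-path
# queries and report p50/p95 latency; adjust the id sampling to your data
import random
import time

import numpy as np

planner = Neo4jPlanner('bolt://localhost:7687', 'neo4j', 'password')
node_ids = [f'service-{i}' for i in range(1000)]   # assumed id scheme

latencies = []
for _ in range(1000):
    start_id, end_id = random.sample(node_ids, 2)
    t0 = time.perf_counter()
    try:
        planner.shortest_path(start_id, end_id)
    except Exception:
        pass   # unreachable pairs still count toward latency
    latencies.append((time.perf_counter() - t0) * 1000)

print(f"p50: {np.percentile(latencies, 50):.1f} ms")
print(f"p95: {np.percentile(latencies, 95):.1f} ms")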

Phase 2: Add Caching (Week 3-4)

# Add a Redis cache layer
import json

import redis

redis_client = redis.Redis(host='localhost')

# Implement cache-aside pattern
def shortest_path_cached(start, end):
    cache_key = f"path:{start}:{end}"
    cached = redis_client.get(cache_key)

    if cached:
        return json.loads(cached)

    # Cache miss - query Neo4j
    path = neo4j_query(start, end)
    redis_client.set(cache_key, json.dumps(path), ex=3600)

    return path

Phase 3: Partition for Scale (Month 2-3)

# Analyze the graph for partitioning (GraphPartitioner from earlier)
partitioner = GraphPartitioner(partition_strategy='domain')
partitions = partitioner.partition_graph(services, dependencies)

# Deploy a regional graph store per partition
for partition_id, partition in partitions.items():
    deploy_regional_graph_store(partition_id, partition)

# Update the planner to use distributed search
# (per-partition DB URLs plus the partition metadata it expects)
planner = DistributedPlanner(partition_db_urls, partition_metadata)

Phase 4: Monitor and Optimize (Ongoing)

# Track key metrics
metrics = {
    'query_latency_p95': monitor_latency(),
    'cache_hit_rate': monitor_cache(),
    'cross_partition_queries': monitor_distribution(),
    'cost_per_query': monitor_cost()
}

# Adjust based on data
if metrics['cache_hit_rate'] < 0.7:
    increase_cache_ttl()

if metrics['cross_partition_queries'] > 0.2:
    rebalance_partitions()

Key Takeaways

Distributed graph traversal:
--> Partition by service domain for natural boundaries
--> Use local planning within partitions (80%+ of queries)
--> Coordinate across partitions only when necessary
--> Cache aggressively (70-90% hit rates achievable)

Database selection:

  • Start with: Neo4j Community (free, feature-rich)
  • Scale to: Hybrid (Redis cache + Neo4j)
  • Large scale: TigerGraph or custom distributed
  • AWS-native: Neptune if already on AWS

Cost optimization:

  • Caching reduces database load by 70-80%
  • Partitioning enables horizontal scaling
  • Hybrid approach: 50-60% cheaper than single database
  • Monitor query patterns to optimize cache strategy

Performance at scale:

  • Well-partitioned graphs: <500ms P95 achievable at 10M nodes
  • Hot path caching: 10-20ms typical
  • Hybrid architecture balances cost and speed

What's Next

Which topic interests you most?

  • Multi-agent coordination (preventing conflicts between autonomous systems)
  • Real-time graph topology updates (handling infrastructure changes mid-planning)
  • Chaos engineering for autonomous systems (testing agent resilience)

Drop a comment with what you'd like to see next.


Try It Yourself

Quick benchmark setup:

  1. Start with Docker Compose:

[Image: Docker Compose setup]

  2. Load test data and compare:

[Image: loading test data and comparing engines]

  3. Measure what matters for YOUR workload

Hit the ❤️ if this helps you architect planet-scale systems.

Share your infrastructure scale and challenges in the comments.

