This article was originally published on AI Study Room. For the full version with working code examples and related articles, visit the original post.
Graph RAG: Knowledge Graphs, Entity Extraction, Relationship Traversal
Introduction
Traditional RAG retrieves documents based on semantic similarity. Graph RAG goes further by modeling entities (people, companies, concepts) and the relationships between them. This enables queries like "Which employees worked on projects managed by Alice?" that require traversing relationships rather than matching text. This article covers building knowledge graphs from documents and using them for retrieval.
Entity Extraction
The first step is extracting entities and their relationships from documents:
from pydantic import BaseModel
# One node of the extracted knowledge graph, as produced by the LLM.
# NOTE: only `#` comments here on purpose. Pydantic v2 copies a class
# docstring into the model's JSON schema, which is presumably what the
# structured-output helper sends to the LLM.
class Entity(BaseModel):
    name: str  # entity name; KnowledgeGraph MERGEs nodes on this value
    type: str  # entity category, e.g. one of the types listed in the extraction prompt
    description: str  # free-text description of the entity
# One directed edge of the extracted knowledge graph, as produced by the LLM.
# Only `#` comments (not a docstring), so the model's JSON schema is unchanged.
class Relationship(BaseModel):
    source: str  # name of the source entity; expected to match an Entity.name
    target: str  # name of the target entity; expected to match an Entity.name
    relationship: str  # relationship label; stored as the `type` property of the RELATES edge
    description: str  # free-text description of the relationship
# Structured-output schema for one extraction call: every entity and
# relationship found in the text. Only `#` comments (not a docstring), so the
# model's JSON schema is unchanged.
class ExtractionResult(BaseModel):
    entities: list[Entity]  # extracted nodes
    relationships: list[Relationship]  # extracted edges between those nodes (by name)
def extract_graph(documents: list[str], max_chars: int = 8000) -> ExtractionResult:
    """Extract entities and relationships from *documents* using an LLM.

    The documents are joined with blank lines and sent to
    ``call_llm_with_structured_output`` in consecutive windows of at most
    *max_chars* characters, so text beyond the first window is processed
    instead of being dropped. When the joined text fits in a single window,
    exactly one call is made with the single-window prompt and its result is
    returned as-is.

    Args:
        documents: Source texts to extract from.
        max_chars: Maximum number of text characters per LLM call. Defaults
            to 8000.

    Returns:
        An ExtractionResult. With several windows, the per-window entity and
        relationship lists are concatenated in window order. Duplicates
        across windows are kept; ``KnowledgeGraph.insert_entities_and_relations``
        MERGEs on entity name and relationship label, so they collapse on
        insert. Empty input yields an empty result without calling the LLM.

    Raises:
        ValueError: If *max_chars* is less than 1.
    """
    if max_chars < 1:
        raise ValueError(f"max_chars must be at least 1, got {max_chars!r}")

    combined_text = "\n\n".join(documents)
    # Fixed-size windows keep every prompt within the size budget. A window
    # boundary can fall mid-sentence, so an entity straddling it may be
    # missed or split.
    windows = [
        combined_text[start:start + max_chars]
        for start in range(0, len(combined_text), max_chars)
    ]
    if not windows:
        return ExtractionResult(entities=[], relationships=[])

    results = []
    for window in windows:
        prompt = f"""
Extract all entities and their relationships from the text below.
Entity types to consider: Person, Organization, Technology, Product, Location, Concept, Event
For each entity, provide: name, type, description
For each relationship, provide: source, target, relationship type, description
Text: {window}
"""
        results.append(call_llm_with_structured_output(prompt, ExtractionResult))

    if len(results) == 1:
        return results[0]
    return ExtractionResult(
        entities=[entity for result in results for entity in result.entities],
        relationships=[rel for result in results for rel in result.relationships],
    )
Building the Knowledge Graph
Use a graph database like Neo4j to store and query the extracted structure:
from neo4j import GraphDatabase
class KnowledgeGraph:
    """Store and query an extracted entity graph in Neo4j.

    Entities become ``(:Entity)`` nodes keyed by ``name`` with ``type`` and
    ``description`` properties. Relationships become ``[:RELATES]`` edges
    whose ``type`` property holds the relationship label.

    The instance owns a Neo4j driver. Call :meth:`close` when done, or use
    the instance as a context manager (``with KnowledgeGraph(...) as kg:``).
    """

    def __init__(self, uri: str, user: str, password: str):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self) -> None:
        """Close the driver and release its pooled connections."""
        self.driver.close()

    def __enter__(self) -> "KnowledgeGraph":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the driver; returning None lets any exception propagate.
        self.close()

    def insert_entities_and_relations(self, extraction: ExtractionResult):
        """Upsert every entity and relationship from *extraction*.

        Entities are MERGEd on ``name``, and their ``type``/``description``
        are overwritten. Relationships are MERGEd on (source, target, label),
        and their ``description`` is overwritten.

        Args:
            extraction: Output of the LLM extraction step.
        """
        with self.driver.session() as session:
            # Create entities
            for entity in extraction.entities:
                session.run(
                    "MERGE (e:Entity {name: $name}) "
                    "SET e.type = $type, e.description = $description",
                    name=entity.name,
                    type=entity.type,
                    description=entity.description,
                )
            # Create relationships. Endpoints are MATCHed, not MERGEd: a
            # relationship whose source or target was not among the extracted
            # entities matches no rows and is skipped without error.
            for rel in extraction.relationships:
                session.run(
                    "MATCH (s:Entity {name: $source}) "
                    "MATCH (t:Entity {name: $target}) "
                    "MERGE (s)-[r:RELATES {type: $relationship}]->(t) "
                    "SET r.description = $description",
                    source=rel.source,
                    target=rel.target,
                    relationship=rel.relationship,
                    description=rel.description,
                )

    def traverse(
        self, start_entity: str, max_depth: int = 2, limit: int = 50
    ) -> list[dict]:
        """Return outgoing RELATES paths of 1..*max_depth* hops from *start_entity*.

        Args:
            start_entity: ``name`` of the Entity node to start from.
            max_depth: Maximum number of hops; must be a positive int.
            limit: Maximum number of paths returned (default 50).

        Returns:
            One dict per path with ``path_nodes`` (node names in order) and
            ``path_rels`` (relationship labels, one fewer than the nodes).

        Raises:
            ValueError: If *max_depth* is not a positive int.
        """
        # Cypher does not accept a parameter as a variable-length bound
        # (``*1..$max_depth`` is a query error), so the depth is validated as
        # a plain positive int and written into the query text.
        if isinstance(max_depth, bool) or not isinstance(max_depth, int) or max_depth < 1:
            raise ValueError(f"max_depth must be a positive int, got {max_depth!r}")
        query = (
            "MATCH path = (start:Entity {name: $start_entity})"
            f"-[:RELATES*1..{max_depth}]->(related) "
            "RETURN [node in nodes(path) | node.name] AS path_nodes, "
            "[rel in relationships(path) | rel.type] AS path_rels "
            "LIMIT $limit"
        )
        with self.driver.session() as session:
            result = session.run(query, start_entity=start_entity, limit=limit)
            return [record.data() for record in result]
Graph + Vector Hybrid Retrieval
A hybrid pattern combines graph traversal (for explicit relationships between entities) with vector similarity (for broader semantic context):
class GraphVectorRetriever:
    """Retrieve context by combining vector similarity search with graph traversal.

    NOTE(review): this listing is truncated in the source article. As shown,
    retrieve() builds ``vector_results`` and ``graph_context`` but never
    returns them.
    """

    def __init__(self, graph: KnowledgeGraph, vector_store):
        # vector_store: presumably any store exposing
        # similarity_search(query, k=...) -- TODO confirm against caller
        self.graph = graph
        self.vector_store = vector_store

    def retrieve(self, query: str, k: int = 5) -> list[str]:
        """Gather vector-search hits for *query* plus rendered graph paths
        from the entities mentioned in it (*k* = number of vector hits)."""
        # Step 1: Identify starting entities from the query
        # NOTE(review): extract_query_entities is not defined in this listing.
        query_entities = self.extract_query_entities(query)
        # Step 2: Vector search for broader context
        vector_results = self.vector_store.similarity_search(query, k=k)
        # Step 3: Graph traversal from identified entities
        graph_context = []
        for entity in query_entities:
            paths = self.graph.traverse(entity, max_depth=2)
            for path in paths:
                # Render "node (rel) -> node (rel) -> last_node". A path has
                # one fewer relationship than nodes, so the final node is
                # printed without a label.
                context = " -> ".join(
                    f"{path['path_nodes'][i]} ({path['path_rels'][i]})"
                    if i < len(path['path_rels'])
                    else path['path_nodes'][i]
                    for i in range(len(path['path_nodes']))
                )
                graph_context.append(context)
        # (Listing truncated in the source article; the remaining steps,
        # including the return statement, are not shown.)
Read the full article on AI Study Room for complete code examples, comparison tables, and related resources.
Found this useful? Check out more developer guides and tool comparisons on AI Study Room.
Top comments (0)