DEV Community

丁久

Originally published at dingjiu1989-hue.github.io

Graph RAG: Knowledge Graphs, Entity Extraction, Relationship Traversal

This article was originally published on AI Study Room. For the full version with working code examples and related articles, visit the original post.


Introduction

Traditional RAG retrieves documents based on their semantic similarity to the query. Graph RAG additionally models the entities those documents mention (people, companies, technologies, concepts) and the relationships between them. This supports questions like "Which employees worked on projects managed by Alice?", which can only be answered by following relationships across several hops rather than by matching text. This article covers building a knowledge graph from documents and using it for retrieval.
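To see why that question needs traversal, here is a minimal sketch over a handful of hypothetical (source, relationship, target) triples; the names and the MANAGES / WORKED_ON relationship types are invented for illustration. The answer comes from joining two hops, and no single text chunk has to mention both Alice and the employees:

```python
# Hypothetical knowledge-graph triples: (source, relationship, target)
TRIPLES = [
    ("Alice", "MANAGES", "Project Apollo"),
    ("Alice", "MANAGES", "Project Zephyr"),
    ("Carol", "MANAGES", "Project Hermes"),
    ("Bob", "WORKED_ON", "Project Apollo"),
    ("Dana", "WORKED_ON", "Project Zephyr"),
    ("Eve", "WORKED_ON", "Project Hermes"),
]


def employees_on_projects_managed_by(manager: str, triples) -> set[str]:
    # Hop 1: every project the manager manages
    projects = {t for s, r, t in triples if s == manager and r == "MANAGES"}
    # Hop 2: every person who worked on one of those projects
    return {s for s, r, t in triples if r == "WORKED_ON" and t in projects}


print(sorted(employees_on_projects_managed_by("Alice", TRIPLES)))  # ['Bob', 'Dana']
```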

Entity Extraction

The first step is to extract entities and their relationships from the documents, here by prompting an LLM to return output that matches a Pydantic schema:

```python
from pydantic import BaseModel


class Entity(BaseModel):
    name: str
    type: str  # one of the entity types listed in the prompt
    description: str


class Relationship(BaseModel):
    source: str  # name of the source entity
    target: str  # name of the target entity
    relationship: str  # relationship type, e.g. "MANAGES"
    description: str


class ExtractionResult(BaseModel):
    entities: list[Entity]
    relationships: list[Relationship]


def extract_graph(documents: list[str]) -> ExtractionResult:
    combined_text = "\n\n".join(documents)
    # Only the first 8,000 characters are sent; longer input is silently truncated.
    prompt = f"""Extract all entities and their relationships from the text below.

Entity types to consider: Person, Organization, Technology, Product, Location, Concept, Event

For each entity, provide: name, type, description
For each relationship, provide: source, target, relationship type, description

Text: {combined_text[:8000]}
"""
    # call_llm_with_structured_output is not defined in this excerpt: it stands in
    # for your LLM client and should return an ExtractionResult parsed from the output.
    return call_llm_with_structured_output(prompt, ExtractionResult)
```
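Because extract_graph sends only the first 8,000 characters to the model, anything past that point never reaches the graph. One workaround, sketched below under the assumption that character count is an acceptable stand-in for the model's token limit, is to pack the documents into chunks of at most 8,000 characters and run extraction once per chunk. chunk_documents is a hypothetical helper, not part of the original code:

```python
def chunk_documents(documents: list[str], max_chars: int = 8000) -> list[str]:
    """Greedily pack documents into chunks of at most max_chars characters.

    Documents are joined with a blank line, as in extract_graph; a document
    longer than max_chars is cut into consecutive max_chars-sized slices.
    """
    if max_chars < 1:
        raise ValueError("max_chars must be positive")
    chunks: list[str] = []
    current = ""
    for doc in documents:
        if not doc:
            continue  # skip empty documents
        # Slice oversized documents so no piece exceeds the budget on its own
        pieces = [doc[i:i + max_chars] for i in range(0, len(doc), max_chars)]
        for piece in pieces:
            candidate = f"{current}\n\n{piece}" if current else piece
            if len(candidate) <= max_chars:
                current = candidate  # piece still fits in the open chunk
            else:
                chunks.append(current)  # close the full chunk, start a new one
                current = piece
    if current:
        chunks.append(current)
    return chunks


docs = ["Alice manages Project Apollo.", "Bob worked on Project Apollo."]
print(chunk_documents(docs, max_chars=40))
```

Each chunk can then be passed to extract_graph([chunk]) on its own. Fixed-size slicing can cut a sentence, and with it a relationship, in half; splitting on paragraph boundaries is a gentler choice if your documents have them.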

Building the Knowledge Graph

Store the extracted structure in a graph database such as Neo4j, where each entity becomes an Entity node and each relationship a RELATES edge that can then be traversed:

```python
from neo4j import GraphDatabase


class KnowledgeGraph:
    def __init__(self, uri: str, user: str, password: str):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        self.driver.close()

    def insert_entities_and_relations(self, extraction: ExtractionResult):
        with self.driver.session() as session:
            # Create (or update) one node per entity, keyed by name
            for entity in extraction.entities:
                session.run(
                    "MERGE (e:Entity {name: $name}) "
                    "SET e.type = $type, e.description = $description",
                    name=entity.name,
                    type=entity.type,
                    description=entity.description,
                )
            # Create relationships. MATCH finds nothing, so the relationship is
            # silently skipped, if source or target was not inserted as an entity.
            for rel in extraction.relationships:
                session.run(
                    "MATCH (s:Entity {name: $source}) "
                    "MATCH (t:Entity {name: $target}) "
                    "MERGE (s)-[r:RELATES {type: $relationship}]->(t) "
                    "SET r.description = $description",
                    source=rel.source,
                    target=rel.target,
                    relationship=rel.relationship,
                    description=rel.description,
                )

    def traverse(self, start_entity: str, max_depth: int = 2) -> list[dict]:
        # Cypher does not accept parameters in variable-length bounds
        # (*1..$max_depth), so validate the depth and inline it as an integer.
        depth = int(max_depth)
        if depth < 1:
            raise ValueError("max_depth must be >= 1")
        query = (
            f"MATCH path = (start:Entity {{name: $start_entity}})-[:RELATES*1..{depth}]->(related) "
            "RETURN [node IN nodes(path) | node.name] AS path_nodes, "
            "[rel IN relationships(path) | rel.type] AS path_rels "
            "LIMIT 50"  # no ORDER BY, so which 50 paths come back is unspecified
        )
        with self.driver.session() as session:
            result = session.run(query, start_entity=start_entity)
            return [record.data() for record in result]
```
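To check what traverse returns without running Neo4j, the sketch below enumerates the same kind of outgoing paths in memory with a breadth-first expansion over hypothetical triples. It returns dictionaries shaped like the Cypher result (path_nodes and path_rels), which can be handy for unit-testing code that consumes traversal output. One assumption to flag: it never revisits a node within a path, which is stricter than Cypher's default of only forbidding reuse of the same relationship.

```python
from collections import defaultdict


def build_adjacency(triples: list[tuple[str, str, str]]) -> dict[str, list[tuple[str, str]]]:
    """Map each source entity to its outgoing (relationship, target) edges."""
    adj: dict[str, list[tuple[str, str]]] = defaultdict(list)
    for source, rel, target in triples:
        adj[source].append((rel, target))
    return adj


def traverse(adj, start_entity: str, max_depth: int = 2, limit: int = 50) -> list[dict]:
    """Return outgoing paths of length 1..max_depth starting at start_entity."""
    results: list[dict] = []
    # Each frontier entry is a partial path: (node names, relationship types)
    frontier = [([start_entity], [])]
    for _ in range(max_depth):
        next_frontier = []
        for nodes, rels in frontier:
            for rel, target in adj.get(nodes[-1], []):
                if target in nodes:
                    continue  # keep paths simple: no node appears twice
                new_nodes, new_rels = nodes + [target], rels + [rel]
                results.append({"path_nodes": new_nodes, "path_rels": new_rels})
                if len(results) >= limit:
                    return results  # plays the role of LIMIT 50
                next_frontier.append((new_nodes, new_rels))
        frontier = next_frontier
    return results


TRIPLES = [
    ("Alice", "MANAGES", "Project Apollo"),
    ("Project Apollo", "STAFFED_BY", "Bob"),
    ("Bob", "REPORTS_TO", "Alice"),
]
adj = build_adjacency(TRIPLES)
for path in traverse(adj, "Alice", max_depth=2):
    print(path)
```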

Graph + Vector Hybrid Retrieval

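A common pattern combines the two: vector similarity search supplies passages related to the query, and graph traversal from the entities named in the query supplies multi-hop relational context: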

```python
class GraphVectorRetriever:
    def __init__(self, graph: KnowledgeGraph, vector_store):
        self.graph = graph
        # Assumed to expose a LangChain-style similarity_search(query, k=...)
        # returning documents with a .page_content attribute
        self.vector_store = vector_store

    def retrieve(self, query: str, k: int = 5) -> list[str]:
        # Step 1: Identify starting entities from the query
        # (extract_query_entities is not defined in this excerpt)
        query_entities = self.extract_query_entities(query)

        # Step 2: Vector search for broader context
        vector_results = self.vector_store.similarity_search(query, k=k)

        # Step 3: Graph traversal from identified entities, rendering
        # each path as "A (REL1) -> B (REL2) -> C"
        graph_context = []
        for entity in query_entities:
            paths = self.graph.traverse(entity, max_depth=2)
            for path in paths:
                nodes, rels = path["path_nodes"], path["path_rels"]
                context = " -> ".join(
                    f"{node} ({rels[i]})" if i < len(rels) else node
                    for i, node in enumerate(nodes)
                )
                graph_context.append(context)

        # Step 4 (the excerpt is cut off here; this merge is an assumption):
        # return vector hits first, then graph paths
        return [doc.page_content for doc in vector_results] + graph_context
```
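The retriever relies on extract_query_entities, which this excerpt does not define. One naive possibility, assuming you already have the list of entity names stored in the graph, is case-insensitive substring matching against the query. The sketch pairs it with the path-to-text formatting used in Step 3; both function names are hypothetical:

```python
def match_query_entities(query: str, known_entities: list[str]) -> list[str]:
    """Naive entity linking: return known entity names that occur in the query.

    Matching is case-insensitive substring search, in the order of known_entities.
    """
    q = query.casefold()
    return [name for name in known_entities if name.casefold() in q]


def format_path(path: dict) -> str:
    """Render {'path_nodes': [...], 'path_rels': [...]} as 'A (REL1) -> B (REL2) -> C'."""
    nodes, rels = path["path_nodes"], path["path_rels"]
    # Node i is followed by relationship i; the final node has no outgoing label
    return " -> ".join(
        f"{node} ({rels[i]})" if i < len(rels) else node
        for i, node in enumerate(nodes)
    )


known = ["Alice", "Bob", "Project Apollo"]
print(match_query_entities("Which employees worked on projects managed by Alice?", known))
# ['Alice']
print(format_path({
    "path_nodes": ["Alice", "Project Apollo", "Bob"],
    "path_rels": ["MANAGES", "STAFFED_BY"],
}))
# Alice (MANAGES) -> Project Apollo (STAFFED_BY) -> Bob
```

Substring matching is crude: a short name such as "AI" also matches inside words like "maintain". An LLM call or a dedicated entity linker is more robust, but the traversal and formatting steps stay the same whichever you use.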
