Harish Kotra (he/him) for Gaia

Originally published at hackmd.io

Creating Wikipedia Search Embeddings with a Gaia Node

This tutorial shows how to create embeddings from Wikipedia articles using a Gaia node. We'll download articles about the 2022 Winter Olympics, split them into searchable sections, and create embeddings through Gaia's OpenAI-compatible API.

Overview

  1. Collect Wikipedia articles about the 2022 Winter Olympics
  2. Split documents into searchable chunks
  3. Create embeddings using your own Gaia node
  4. Store results in a CSV file

Prerequisites

Install required packages:

pip install mwclient mwparserfromhell pandas tiktoken openai
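
Before running the full script, it can help to confirm that the node's embedding endpoint responds. Below is a minimal smoke test, assuming the same public node URL and nomic-embed model used in the script; substitute your own node's base URL and model name if they differ.

from openai import OpenAI

# Point the client at the Gaia node's OpenAI-compatible API.
client = OpenAI(base_url="https://llama8b.gaia.domains/v1/", api_key="not-needed")

# Request a single embedding as a quick health check.
response = client.embeddings.create(model="nomic-embed", input=["hello, world"])
print(f"Embedding length: {len(response.data[0].embedding)}")

If this prints a non-zero vector length, the node is ready for the full run.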

Code

Here's the complete code that implements the workflow:

import mwclient
import mwparserfromhell
import pandas as pd
import re
import tiktoken
import time
from openai import OpenAI

# Configure constants
GAIA_NODE_URL = "https://llama8b.gaia.domains/v1/"
CATEGORY_TITLE = "Category:2022 Winter Olympics"
WIKI_SITE = "en.wikipedia.org"
MAX_TOKENS = 1600
BATCH_SIZE = 50

# Initialize OpenAI client with Gaia endpoint 
client = OpenAI(base_url=GAIA_NODE_URL, api_key="not-needed")

SECTIONS_TO_IGNORE = [
    "See also", "References", "External links", "Further reading", 
    "Footnotes", "Bibliography", "Sources", "Citations", "Literature",
    "Notes and references", "Photo gallery", "Works cited", "Photos",
    "Gallery", "Notes", "References and sources", "References and notes",
]

def titles_from_category(category: mwclient.listing.Category, max_depth: int) -> set[str]:
    """Get all page titles from a Wikipedia category and its subcategories."""
    titles = set()
    for cm in category.members():
        if type(cm) == mwclient.page.Page:
            titles.add(cm.name)
        elif isinstance(cm, mwclient.listing.Category) and max_depth > 0:
            deeper_titles = titles_from_category(cm, max_depth=max_depth - 1)
            titles.update(deeper_titles)
    return titles

def all_subsections_from_section(section: mwparserfromhell.wikicode.Wikicode, 
                               parent_titles: list[str], 
                               sections_to_ignore: set[str]) -> list[tuple[list[str], str]]:
    """Extract all subsections from a Wikipedia section."""
    headings = [str(h) for h in section.filter_headings()]
    title = headings[0]
    if title.strip("=" + " ") in sections_to_ignore:
        return []
    titles = parent_titles + [title]
    full_text = str(section)
    section_text = full_text.split(title)[1]
    if len(headings) == 1:
        return [(titles, section_text)]
    else:
        first_subtitle = headings[1]
        section_text = section_text.split(first_subtitle)[0]
        results = [(titles, section_text)]
        for subsection in section.get_sections(levels=[len(titles) + 1]):
            results.extend(all_subsections_from_section(subsection, titles, sections_to_ignore))
        return results

def all_subsections_from_title(title: str,
                             sections_to_ignore: set[str] = SECTIONS_TO_IGNORE,
                             site_name: str = WIKI_SITE) -> list[tuple[list[str], str]]:
    """Get all subsections from a Wikipedia page title."""
    site = mwclient.Site(site_name)
    page = site.pages[title]
    text = page.text()
    parsed_text = mwparserfromhell.parse(text)
    headings = [str(h) for h in parsed_text.filter_headings()]
    if headings:
        summary_text = str(parsed_text).split(headings[0])[0]
    else:
        summary_text = str(parsed_text)
    results = [([title], summary_text)]
    for subsection in parsed_text.get_sections(levels=[2]):
        results.extend(all_subsections_from_section(subsection, [title], sections_to_ignore))
    return results

def clean_section(section: tuple[list[str], str]) -> tuple[list[str], str]:
    """Clean up a Wikipedia section by removing references and whitespace."""
    titles, text = section
    text = re.sub(r"<ref.*?</ref>", "", text)
    text = text.strip()
    return (titles, text)

def keep_section(section: tuple[list[str], str]) -> bool:
    """Determine if a section should be kept based on length."""
    titles, text = section
    return len(text) >= 16

def num_tokens(text: str) -> int:
    """Count the number of tokens in a text."""
    # The gpt-4 tokenizer is used only as a rough proxy for chunk sizing; the
    # embedding model served by the Gaia node may tokenize somewhat differently.
    encoding = tiktoken.encoding_for_model("gpt-4")
    return len(encoding.encode(text))

def halved_by_delimiter(string: str, delimiter: str = "\n") -> list[str]:
    """Split a string in two parts at a delimiter, balancing tokens."""
    chunks = string.split(delimiter)
    if len(chunks) == 1:
        return [string, ""]
    elif len(chunks) == 2:
        return chunks
    else:
        total_tokens = num_tokens(string)
        halfway = total_tokens // 2
        best_diff = halfway
        for i, chunk in enumerate(chunks):
            left = delimiter.join(chunks[: i + 1])
            left_tokens = num_tokens(left)
            diff = abs(halfway - left_tokens)
            if diff >= best_diff:
                break
            else:
                best_diff = diff
        left = delimiter.join(chunks[:i])
        right = delimiter.join(chunks[i:])
        return [left, right]

def split_strings_from_subsection(subsection: tuple[list[str], str],
                                max_tokens: int = 1000,
                                max_recursion: int = 5) -> list[str]:
    """Split a subsection into smaller pieces that fit within max_tokens."""
    titles, text = subsection
    string = "\n\n".join(titles + [text])
    if num_tokens(string) <= max_tokens:
        return [string]
    elif max_recursion == 0:
        return [string[:max_tokens]]
    else:
        titles, text = subsection
        for delimiter in ["\n\n", "\n", ". "]:
            left, right = halved_by_delimiter(text, delimiter=delimiter)
            if left == "" or right == "":
                continue
            else:
                results = []
                for half in [left, right]:
                    half_subsection = (titles, half)
                    half_strings = split_strings_from_subsection(
                        half_subsection,
                        max_tokens=max_tokens,
                        max_recursion=max_recursion - 1,
                    )
                    results.extend(half_strings)
                return results
    return [string[:max_tokens]]

def main():
    # 1. Collect Wikipedia articles
    site = mwclient.Site(WIKI_SITE)
    category_page = site.pages[CATEGORY_TITLE]
    titles = titles_from_category(category_page, max_depth=1)
    print(f"Found {len(titles)} articles")

    # 2. Extract and clean sections
    wikipedia_sections = []
    for title in titles:
        wikipedia_sections.extend(all_subsections_from_title(title))
    print(f"Found {len(wikipedia_sections)} sections")

    wikipedia_sections = [clean_section(ws) for ws in wikipedia_sections]
    wikipedia_sections = [ws for ws in wikipedia_sections if keep_section(ws)]

    # 3. Split into chunks
    wikipedia_strings = []
    for section in wikipedia_sections:
        wikipedia_strings.extend(split_strings_from_subsection(section, max_tokens=MAX_TOKENS))
    print(f"Split into {len(wikipedia_strings)} chunks")

    # 4. Get embeddings with retries
    embeddings = []
    for batch_start in range(0, len(wikipedia_strings), BATCH_SIZE):
        batch_end = batch_start + BATCH_SIZE
        batch = wikipedia_strings[batch_start:batch_end]
        print(f"Processing batch {batch_start} to {batch_end-1}")

        max_retries = 3
        for attempt in range(max_retries):
            try:
                response = client.embeddings.create(
                    model="nomic-embed",
                    input=batch
                )
                batch_embeddings = [e.embedding for e in response.data]
                embeddings.extend(batch_embeddings)
                break
            except Exception as e:
                print(f"Error on attempt {attempt + 1}: {str(e)}")
                if attempt == max_retries - 1:
                    raise
                time.sleep(5 * (attempt + 1))  # Linear backoff: wait 5s, 10s, 15s

    # 5. Save to CSV
    df = pd.DataFrame({"text": wikipedia_strings, "embedding": embeddings})
    df.to_csv("winter_olympics_2022.csv", index=False)
    print("Saved embeddings to winter_olympics_2022.csv")

if __name__ == "__main__":
    main()

How It Works

  1. Document Collection: Downloads Wikipedia articles from the "2022 Winter Olympics" category.

  2. Text Processing:
    • Splits articles into sections
    • Removes references and cleans up the text
    • Excludes short sections and irrelevant parts like "References"
    • Splits long sections to fit within the token limit (see the short sketch after this list)

  3. Embedding Generation:
    • Uses Gaia's OpenAI-compatible endpoint
    • Processes text in small batches (BATCH_SIZE = 50 chunks)
    • Includes retry logic for reliability
    • Uses the nomic-embed model

  4. Storage: Saves text chunks and their embeddings to a CSV file.
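
To see the text processing in isolation, you can run the helper functions above on a single article before committing to the full category. This is a minimal sketch, assuming the functions and constants from wikipedia_embeddings.py are in scope (for example, pasted into the same file or an interactive session); the article title is just an illustrative member of the category.

# Pull one article, clean and filter its sections, then chunk to MAX_TOKENS.
sections = all_subsections_from_title("Curling at the 2022 Winter Olympics")
sections = [clean_section(s) for s in sections]
sections = [s for s in sections if keep_section(s)]

chunks = []
for section in sections:
    chunks.extend(split_strings_from_subsection(section, max_tokens=MAX_TOKENS))

print(f"{len(sections)} sections -> {len(chunks)} chunks")
print(f"Largest chunk: {max(num_tokens(c) for c in chunks)} tokens")  # normally <= MAX_TOKENS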

Usage

Save the code as wikipedia_embeddings.py and run:

python wikipedia_embeddings.py

The script will create a CSV file containing text chunks and their embeddings, which can be used for semantic search or other applications.
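
As a rough illustration, here is one way to load the CSV back and rank chunks against a question by cosine similarity. This is a sketch, assuming the winter_olympics_2022.csv file produced above and the same Gaia node and nomic-embed model; note that pandas writes the embedding column as strings, so it has to be parsed on the way back in.

import ast

import numpy as np
import pandas as pd
from openai import OpenAI

client = OpenAI(base_url="https://llama8b.gaia.domains/v1/", api_key="not-needed")

# Load the chunks and parse the stringified embedding column back into lists.
df = pd.read_csv("winter_olympics_2022.csv")
df["embedding"] = df["embedding"].apply(ast.literal_eval)

# Embed the question with the same model, then score every chunk by cosine similarity.
query = "Which country won the most gold medals?"
query_embedding = client.embeddings.create(model="nomic-embed", input=[query]).data[0].embedding

matrix = np.array(df["embedding"].tolist())
query_vec = np.array(query_embedding)
scores = matrix @ query_vec / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vec))

# Print the three most similar chunks.
for idx in np.argsort(scores)[::-1][:3]:
    print(round(float(scores[idx]), 3), df["text"].iloc[idx][:80].replace("\n", " "))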

Key Features

  • OpenAI-compatible API usage
  • Robust error handling with retries
  • Efficient text chunking
  • Clean Wikipedia text processing
  • Token-aware splitting

Notes

  • The script uses a public Gaia node, so no API key is needed
  • Adjust BATCH_SIZE if you encounter timeouts
  • The embeddings CSV can be large, so ensure you have sufficient disk space
  • The process may take several minutes, depending on the article count

You can find ready-to-use Gaia nodes on Gaia's Public Nodes page.

Credits

Inspired by the OpenAI Cookbook.

