Building a High-Performance Text Embedding API with Rust, Axum, Candle and ONNX

Text embeddings are the backbone of modern AI applications—from semantic search to recommendation systems. In this tutorial, we'll build a production-ready embedding API in Rust that supports two models: a lightweight all-MiniLM-L6-v2 model and Google's EmbeddingGemma.

What We'll Build

A REST API with two endpoints:

  • /embed-mini - Fast embeddings using all-MiniLM-L6-v2 (ONNX)
  • /generate-embedding - Embeddings using EmbeddingGemma (Candle)

Prerequisites

  • Rust installed (1.70+)
  • Basic understanding of async Rust
  • Familiarity with REST APIs

Step 1: Set Up Your Project

Create a new Rust project and add dependencies:

cargo new embedding-api
cd embedding-api

Add these dependencies to your Cargo.toml:

[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
candle-core = "0.4"
ort = "1.16"
ndarray = "0.15"
tokenizers = "0.15"
reqwest = "0.11"
anyhow = "1.0"

Step 2: Define Your Data Structures

Create a new file src/embeddings.rs and define the request/response types:

use serde::{Deserialize, Serialize};

#[derive(Deserialize)]
pub struct EmbeddingRequest {
    text: String,
}

#[derive(Serialize)]
pub struct EmbeddingResponse {
    embedding: Vec<f32>,
    dimension: usize,
}

#[derive(Serialize)]
struct ErrorResponse {
    error: String,
}
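
With serde's derive macros these types map directly to the JSON the API exchanges. Roughly (the values are placeholders; the MiniLM model used here produces 384-dimensional vectors):

Request body:   {"text": "Hello, world!"}
Response body:  {"embedding": [0.0123, -0.0456, ...], "dimension": 384}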

Step 3: Build the ONNX Embedder (MiniLM)

The MiniLM model runs through ONNX Runtime, which keeps CPU inference fast and lightweight:

use ort::{Environment, ExecutionProvider, Session, SessionBuilder};
use tokenizers::Tokenizer;
use std::sync::Arc;

pub struct Embedder {
    session: Session,
    tokenizer: Arc<Tokenizer>,
}

impl Embedder {
    pub fn new(model_path: &str, tokenizer_path: &str) -> anyhow::Result<Self> {
        println!("Loading ONNX model from: {}", model_path);

        let environment = Environment::builder()
            .with_name("embedder")
            .with_execution_providers([ExecutionProvider::CPU(Default::default())])
            .build()?
            .into_arc();

        let session = SessionBuilder::new(&environment)?
            .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
            .with_intra_threads(4)?
            .with_model_from_file(model_path)?;

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;

        println!("✅ ONNX model loaded!");

        Ok(Self {
            session,
            tokenizer: Arc::new(tokenizer),
        })
    }
}
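
The three tensors passed to run() in the next step are positional, so their order has to match the model's declared inputs. Here is a small optional helper to print them; this is a sketch that assumes the ort 1.x Session exposes its input/output metadata through the inputs and outputs fields:

impl Embedder {
    /// Print the model's declared inputs and outputs. Purely a debugging aid;
    /// not required by the rest of the tutorial.
    pub fn describe(&self) {
        for input in &self.session.inputs {
            println!("model input:  {}", input.name);
        }
        for output in &self.session.outputs {
            println!("model output: {}", output.name);
        }
    }
}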

Step 4: Implement the Embedding Logic

Add the embedding generation method to the Embedder:

use ndarray::{Array2, CowArray};

impl Embedder {
    pub fn embedd(&self, text: String) -> anyhow::Result<Vec<f32>> {
        // 1. Tokenize the input
        let encoding = self.tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;

        let ids = encoding.get_ids();
        let mask = encoding.get_attention_mask();
        let seq_len = ids.len();

        // 2. Prepare inputs as i64
        let input_ids: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
        let attention_mask: Vec<i64> = mask.iter().map(|&x| x as i64).collect();
        let token_type_ids = vec![0i64; seq_len];

        // 3. Create 2D arrays
        let input_ids_arr = Array2::from_shape_vec((1, seq_len), input_ids)?;
        let attention_mask_arr = Array2::from_shape_vec((1, seq_len), attention_mask)?;
        let token_type_ids_arr = Array2::from_shape_vec((1, seq_len), token_type_ids)?;

        // 4. Convert to ORT values
        let input_ids_val = ort::Value::from_array(
            self.session.allocator(), 
            &CowArray::from(&input_ids_arr.into_dyn())
        )?;
        let attention_mask_val = ort::Value::from_array(
            self.session.allocator(), 
            &CowArray::from(&attention_mask_arr.into_dyn())
        )?;
        let token_type_ids_val = ort::Value::from_array(
            self.session.allocator(), 
            &CowArray::from(&token_type_ids_arr.into_dyn())
        )?;

        // 5. Run inference
        let outputs = self.session.run(vec![
            input_ids_val, 
            attention_mask_val, 
            token_type_ids_val
        ])?;

        // 6. Extract embeddings
        let embeddings_tensor = outputs[0].try_extract::<f32>()?;
        let embeddings = embeddings_tensor.view();

        // 7. Mean pooling
        let hidden_size = 384;
        let mut pooled = vec![0.0f32; hidden_size];

        for token_idx in 0..seq_len {
            for dim_idx in 0..hidden_size {
                pooled[dim_idx] += embeddings[[0, token_idx, dim_idx]];
            }
        }

        for val in pooled.iter_mut() {
            *val /= seq_len as f32;
        }

        // 8. Normalize to unit vector
        let length: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
        let normalized: Vec<f32> = pooled.iter().map(|x| x / length).collect();

        Ok(normalized)
    }
}

Key concepts:

  • Tokenization: Converts text to numerical IDs
  • Mean pooling: Averages token embeddings into a single vector
  • Normalization: Ensures embeddings are comparable via cosine similarity
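
Because the final step normalizes to unit length, cosine similarity between two embeddings is just their dot product. A quick illustration (the cosine_similarity helper below is not part of the API code above):

/// Dot product of two unit-length embeddings = their cosine similarity.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

// Usage sketch, assuming an Embedder built as in Step 3:
// let a = embedder.embedd("How do I reset my password?".to_string())?;
// let b = embedder.embedd("I forgot my login credentials".to_string())?;
// println!("similarity: {:.3}", cosine_similarity(&a, &b));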

Step 5: Load the EmbeddingGemma Model

For the Candle-based approach using Google's EmbeddingGemma, load the tokenizer and the raw safetensors weights:

use candle_core::{Device, Tensor, safetensors};
use std::collections::HashMap;
use tokenizers::Tokenizer;

pub async fn load_model() -> anyhow::Result<(HashMap<String, Tensor>, Tokenizer, Device)> {
    let device = Device::Cpu;
    let model_path = std::path::Path::new("models/embeddgemma");

    let tokenizer_file = model_path.join("tokenizer.json");
    let tokenizer = Tokenizer::from_file(&tokenizer_file)
        .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;

    let model_file = model_path.join("model.safetensors");
    let tensors = safetensors::load(model_file, &device)?;

    println!("Loaded {:?} tensors", tensors.len());
    Ok((tensors, tokenizer, device))
}
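
Step 6 below looks up a tensor by the hard-coded name embed_tokens.weight. The exact key depends on how the checkpoint was exported, so after calling load_model() it is worth printing what the safetensors file actually contains and adjusting the name if needed:

// Debugging aid: list every tensor in the checkpoint with its shape,
// so the embedding-table key used in the next step can be verified.
for (name, tensor) in &tensors {
    println!("{name}: {:?}", tensor.shape());
}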

Step 6: Create the Embedding Generation Function

This function embeds text by looking up each token ID in the model's token-embedding table (embed_tokens.weight), mean-pooling the per-token vectors, and normalizing the result. Note that it uses only the embedding matrix; it does not run the full transformer forward pass.

pub async fn generate_embedding_internal(
    state: &AppState,
    text: String,
) -> Result<Vec<f32>, String> {
    // Tokenize
    let tokens = state.tokenizer
        .encode(text, true)
        .map_err(|e| format!("Tokenization error: {}", e))?
        .get_ids()
        .to_vec();

    // Get embedding matrix
    let embed_weights = state.tensors
        .get("embed_tokens.weight")
        .ok_or("embed_tokens.weight not found")?;

    // Look up embeddings for each token
    let mut embeddings_vec = Vec::new();
    for &token_id in &tokens {
        let token_tensor = Tensor::new(&[token_id as u32], &state.device)
            .map_err(|e| format!("Failed to create token tensor: {}", e))?;

        let token_embed = embed_weights
            .index_select(&token_tensor, 0)
            .map_err(|e| format!("Embedding lookup error: {}", e))?;

        embeddings_vec.push(token_embed);
    }

    // Stack and pool
    let stacked = Tensor::stack(&embeddings_vec, 0)
        .map_err(|e| format!("Stacking error: {}", e))?;

    let pooled = stacked.mean(0)
        .map_err(|e| format!("Pooling error: {}", e))?;

    // Convert to Vec<f32>
    let embedding_vec: Vec<f32> = pooled
        .squeeze(0)
        .map_err(|e| format!("Squeeze error: {}", e))?
        .to_vec1::<f32>()
        .map_err(|e| format!("Tensor conversion error: {}", e))?;

    // Normalize
    let length: f32 = embedding_vec.iter().map(|x| x * x).sum::<f32>().sqrt();
    let normalized: Vec<f32> = embedding_vec.iter().map(|x| x / length).collect();

    Ok(normalized)
}

Step 7: Create API Handlers

Both handlers reuse the request and response types from Step 2 and return a JSON error body on failure:
use axum::{Json, extract::State, response::IntoResponse};
use axum::http::StatusCode;

pub async fn embed_mini(
    State(state): State<AppState>,
    Json(request): Json<EmbeddingRequest>,
) -> impl IntoResponse {
    match state.embedder.embedd(request.text) {
        Ok(embedding) => {
            // Read the length before the vector is moved into the response.
            let dimension = embedding.len();
            (
                StatusCode::OK,
                Json(EmbeddingResponse {
                    embedding,
                    dimension,
                }),
            ).into_response()
        }
        Err(e) => {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: format!("Embedding generation failed: {}", e),
                }),
            ).into_response()
        }
    }
}

pub async fn generate_embedding(
    State(state): State<AppState>,
    Json(request): Json<EmbeddingRequest>,
) -> impl IntoResponse {
    match generate_embedding_internal(&state, request.text).await {
        Ok(embedding) => {
            let dimension = embedding.len();
            (
                StatusCode::OK,
                Json(EmbeddingResponse {
                    embedding,
                    dimension,
                }),
            ).into_response()
        }
        Err(e) => {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: format!("Embedding generation failed: {}", e),
                }),
            ).into_response()
        }
    }
}

Step 8: Set Up Your Main Application

In src/main.rs:

use axum::{Router, routing::post};
use candle_core::{Device, Tensor};
use std::collections::HashMap;
use std::sync::Arc;
use tokenizers::Tokenizer;

// Assumes the code from Steps 2-7 lives in src/embeddings.rs.
mod embeddings;
use embeddings::{Embedder, embed_mini, generate_embedding, load_model};

#[derive(Clone)]
pub struct AppState {
    pub embedder: Arc<Embedder>,
    pub tensors: HashMap<String, Tensor>,
    pub tokenizer: Arc<Tokenizer>,
    pub device: Device,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Load models
    let embedder = Embedder::new(
        "models/minilm/model.onnx",
        "models/minilm/tokenizer.json"
    )?;

    let (tensors, tokenizer, device) = load_model().await?;

    let state = AppState {
        embedder: Arc::new(embedder),
        tensors,
        tokenizer: Arc::new(tokenizer),
        device,
    };

    // Create router
    let app = Router::new()
        .route("/embed-mini", post(embed_mini))
        .route("/generate-embedding", post(generate_embedding))
        .with_state(state);

    // Start server
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    println!("🚀 Server running on http://localhost:3000");

    axum::serve(listener, app).await?;
    Ok(())
}
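
The post doesn't pin down which file each snippet lives in. One workable layout, assuming everything from Steps 2-7 sits in the embeddings module declared in the main.rs above:

src/
├── main.rs        (AppState, router, server startup - Step 8)
└── embeddings.rs  (request/response types, Embedder, load_model,
                    generate_embedding_internal and both handlers - Steps 2-7)

With this split, embeddings.rs needs a use crate::AppState; line at the top so the handlers can take the shared state, which is why the AppState fields above are declared pub.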

Step 9: Download Your Models

Create a models directory and download:

MiniLM (ONNX):

mkdir -p models/minilm
# Download model.onnx and tokenizer.json from Hugging Face:
# sentence-transformers/all-MiniLM-L6-v2 (the ONNX export lives in the repo's onnx/ folder)

EmbeddingGemma:

mkdir -p models/embeddgemma
# Download tokenizer.json and model.safetensors from Hugging Face: google/embeddinggemma-300m

Step 10: Test Your API

Start the server:

cargo run --release

Test with curl:

# MiniLM endpoint
curl -X POST http://localhost:3000/embed-mini \
  -H "Content-Type: application/json" \
  -d '{"text": "Hello, world!"}'

# EmbeddingGemma endpoint
curl -X POST http://localhost:3000/generate-embedding \
  -H "Content-Type: application/json" \
  -d '{"text": "Rust is amazing!"}'

Performance Tips

  1. Use release mode for production: cargo build --release
  2. Adjust thread count in ONNX builder based on your CPU
  3. Add caching for frequently requested embeddings (see the sketch after this list)
  4. Consider GPU support for larger models using CUDA execution provider
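
As a sketch of the caching idea from tip 3 (CachedEmbedder is a hypothetical wrapper, not part of the code above; a real service would bound the map, for example with an LRU):

use std::collections::HashMap;
use std::sync::Mutex;

/// Memoizes MiniLM embeddings per input text.
pub struct CachedEmbedder {
    inner: Embedder,
    cache: Mutex<HashMap<String, Vec<f32>>>,
}

impl CachedEmbedder {
    pub fn new(inner: Embedder) -> Self {
        Self { inner, cache: Mutex::new(HashMap::new()) }
    }

    pub fn embedd(&self, text: String) -> anyhow::Result<Vec<f32>> {
        // Return the cached vector if this text was embedded before.
        if let Some(hit) = self.cache.lock().unwrap().get(&text) {
            return Ok(hit.clone());
        }
        let embedding = self.inner.embedd(text.clone())?;
        self.cache.lock().unwrap().insert(text, embedding.clone());
        Ok(embedding)
    }
}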

Next Steps

  • Add batch processing support
  • Implement model caching strategies
  • Add metrics and monitoring
  • Deploy with Docker
  • Add authentication

Conclusion

You've built a production-ready embedding API in Rust! This setup gives you:

  • Fast inference with ONNX Runtime
  • Flexible model support (MiniLM, EmbeddingGemma)
  • Type-safe request handling
  • Easy integration with downstream applications

The normalized embeddings are ready for semantic search, clustering, or any similarity-based tasks.

Connect with me on LinkedIn


Questions? Drop them in the comments below!
