Text embeddings are the backbone of modern AI applications, from semantic search to recommendation systems. In this tutorial, we'll build a production-ready embedding API in Rust that supports two models: the lightweight all-MiniLM-L6-v2 (served via ONNX Runtime) and Google's EmbeddingGemma (served via Candle).
What We'll Build
A REST API with two endpoints:
- /embed-mini - Fast embeddings using all-MiniLM-L6-v2 (ONNX)
- /generate-embedding - Embeddings using EmbeddingGemma (Candle)
Prerequisites
- Rust installed (1.70+)
- Basic understanding of async Rust
- Familiarity with REST APIs
Step 1: Set Up Your Project
Create a new Rust project and add dependencies:
cargo new embedding-api
cd embedding-api
Add these dependencies to your Cargo.toml:
[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
candle-core = "0.4"
ort = "1.16"
ndarray = "0.15"
tokenizers = "0.15"
reqwest = "0.11"
anyhow = "1.0"
Step 2: Define Your Data Structures
Create a new file src/embeddings.rs and define the request/response types:
use serde::{Deserialize, Serialize};
#[derive(Deserialize)]
pub struct EmbeddingRequest {
    pub text: String,
}

#[derive(Serialize)]
pub struct EmbeddingResponse {
    pub embedding: Vec<f32>,
    pub dimension: usize,
}

#[derive(Serialize)]
pub struct ErrorResponse {
    pub error: String,
}
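These types map directly to the JSON bodies the API exchanges. As a quick sanity check, a hypothetical serde round-trip test (using the serde_json crate already in our dependencies) shows the expected shapes:
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn request_deserializes_from_json() {
        let req: EmbeddingRequest =
            serde_json::from_str(r#"{"text": "Hello, world!"}"#).unwrap();
        assert_eq!(req.text, "Hello, world!");
    }

    #[test]
    fn response_serializes_to_json() {
        let res = EmbeddingResponse { embedding: vec![0.1, 0.2], dimension: 2 };
        let json = serde_json::to_string(&res).unwrap();
        assert!(json.contains("\"dimension\":2"));
    }
}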
Step 3: Build the ONNX Embedder (MiniLM)
The MiniLM model runs via ONNX Runtime (the ort crate), which gives fast CPU inference:
use ort::{Environment, ExecutionProvider, Session, SessionBuilder};
use tokenizers::Tokenizer;
use std::sync::Arc;

pub struct Embedder {
    session: Session,
    tokenizer: Arc<Tokenizer>,
}

impl Embedder {
    pub fn new(model_path: &str, tokenizer_path: &str) -> anyhow::Result<Self> {
        println!("Loading ONNX model from: {}", model_path);
        let environment = Environment::builder()
            .with_name("embedder")
            .with_execution_providers([ExecutionProvider::CPU(Default::default())])
            .build()?
            .into_arc();
        let session = SessionBuilder::new(&environment)?
            .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
            .with_intra_threads(4)?
            .with_model_from_file(model_path)?;
        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
        println!("✅ ONNX model loaded!");
        Ok(Self {
            session,
            tokenizer: Arc::new(tokenizer),
        })
    }
}
Step 4: Implement the Embedding Logic
Add the embedding generation method to the Embedder:
use ndarray::{Array2, CowArray};

impl Embedder {
    pub fn embed(&self, text: String) -> anyhow::Result<Vec<f32>> {
        // 1. Tokenize the input
        let encoding = self.tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
        let ids = encoding.get_ids();
        let mask = encoding.get_attention_mask();
        let seq_len = ids.len();

        // 2. Prepare inputs as i64
        let input_ids: Vec<i64> = ids.iter().map(|&x| x as i64).collect();
        let attention_mask: Vec<i64> = mask.iter().map(|&x| x as i64).collect();
        let token_type_ids = vec![0i64; seq_len];

        // 3. Create (1, seq_len) arrays; they must outlive the ORT values below
        let input_ids_arr =
            CowArray::from(Array2::from_shape_vec((1, seq_len), input_ids)?.into_dyn());
        let attention_mask_arr =
            CowArray::from(Array2::from_shape_vec((1, seq_len), attention_mask)?.into_dyn());
        let token_type_ids_arr =
            CowArray::from(Array2::from_shape_vec((1, seq_len), token_type_ids)?.into_dyn());

        // 4. Convert to ORT values
        let input_ids_val =
            ort::Value::from_array(self.session.allocator(), &input_ids_arr)?;
        let attention_mask_val =
            ort::Value::from_array(self.session.allocator(), &attention_mask_arr)?;
        let token_type_ids_val =
            ort::Value::from_array(self.session.allocator(), &token_type_ids_arr)?;

        // 5. Run inference
        let outputs = self.session.run(vec![
            input_ids_val,
            attention_mask_val,
            token_type_ids_val,
        ])?;

        // 6. Extract the last hidden state (shape: [1, seq_len, hidden_size])
        let embeddings_tensor = outputs[0].try_extract::<f32>()?;
        let embeddings = embeddings_tensor.view();

        // 7. Mean pooling over the token dimension
        let hidden_size = 384; // all-MiniLM-L6-v2 hidden size
        let mut pooled = vec![0.0f32; hidden_size];
        for token_idx in 0..seq_len {
            for dim_idx in 0..hidden_size {
                pooled[dim_idx] += embeddings[[0, token_idx, dim_idx]];
            }
        }
        for val in pooled.iter_mut() {
            *val /= seq_len as f32;
        }

        // 8. Normalize to a unit vector
        let length: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
        let normalized: Vec<f32> = pooled.iter().map(|x| x / length).collect();
        Ok(normalized)
    }
}
Key concepts:
- Tokenization: Converts text to numerical IDs
- Mean pooling: Averages token embeddings into a single vector
- Normalization: Ensures embeddings are comparable via cosine similarity
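Because both vectors come out L2-normalized, cosine similarity between two embeddings reduces to a plain dot product. A tiny helper (not part of the API itself) makes this concrete:
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // For unit-length vectors the dot product equals the cosine of the angle
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

// Values near 1.0 mean the two texts are semantically close;
// values near 0.0 mean they are unrelated.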
Step 5: Load the EmbeddingGemma Model
For the Candle-based approach using Google's EmbeddingGemma, load the tokenizer and the safetensors weights with candle-core:
use candle_core::{Device, Tensor, safetensors};
use std::collections::HashMap;
use tokenizers::Tokenizer;

pub async fn load_model() -> anyhow::Result<(HashMap<String, Tensor>, Tokenizer, Device)> {
    let device = Device::Cpu;
    let model_path = std::path::Path::new("models/embeddgemma");

    let tokenizer_file = model_path.join("tokenizer.json");
    let tokenizer = Tokenizer::from_file(&tokenizer_file)
        .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;

    let model_file = model_path.join("model.safetensors");
    let tensors = safetensors::load(model_file, &device)?;
    println!("Loaded {} tensors", tensors.len());

    Ok((tensors, tokenizer, device))
}
Step 6: Create the Embedding Generation Function
This simplified approach looks up each token in the model's embedding table and mean-pools the results; it does not run the full transformer forward pass, so treat it as a lightweight approximation of EmbeddingGemma's output:
pub async fn generate_embedding_internal(
    state: &AppState,
    text: String,
) -> Result<Vec<f32>, String> {
    // Tokenize
    let tokens = state.tokenizer
        .encode(text, true)
        .map_err(|e| format!("Tokenization error: {}", e))?
        .get_ids()
        .to_vec();

    // Get the embedding matrix
    let embed_weights = state.tensors
        .get("embed_tokens.weight")
        .ok_or("embed_tokens.weight not found")?;

    // Look up the embedding for each token
    let mut embeddings_vec = Vec::new();
    for &token_id in &tokens {
        let token_tensor = Tensor::new(&[token_id], &state.device)
            .map_err(|e| format!("Failed to create token tensor: {}", e))?;
        let token_embed = embed_weights
            .index_select(&token_tensor, 0)
            .map_err(|e| format!("Embedding lookup error: {}", e))?;
        embeddings_vec.push(token_embed);
    }

    // Stack and mean-pool over the token dimension
    let stacked = Tensor::stack(&embeddings_vec, 0)
        .map_err(|e| format!("Stacking error: {}", e))?;
    let pooled = stacked.mean(0)
        .map_err(|e| format!("Pooling error: {}", e))?;

    // Convert to Vec<f32>
    let embedding_vec: Vec<f32> = pooled
        .squeeze(0)
        .map_err(|e| format!("Squeeze error: {}", e))?
        .to_vec1::<f32>()
        .map_err(|e| format!("Tensor conversion error: {}", e))?;

    // Normalize to a unit vector
    let length: f32 = embedding_vec.iter().map(|x| x * x).sum::<f32>().sqrt();
    let normalized: Vec<f32> = embedding_vec.iter().map(|x| x / length).collect();
    Ok(normalized)
}
Step 7: Create API Handlers
use axum::{Json, extract::State, http::StatusCode, response::IntoResponse};

pub async fn embed_mini(
    State(state): State<AppState>,
    Json(request): Json<EmbeddingRequest>,
) -> impl IntoResponse {
    match state.embedder.embed(request.text) {
        Ok(embedding) => {
            let dimension = embedding.len();
            (
                StatusCode::OK,
                Json(EmbeddingResponse { embedding, dimension }),
            ).into_response()
        }
        Err(e) => {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: format!("Embedding generation failed: {}", e),
                }),
            ).into_response()
        }
    }
}

pub async fn generate_embedding(
    State(state): State<AppState>,
    Json(request): Json<EmbeddingRequest>,
) -> impl IntoResponse {
    match generate_embedding_internal(&state, request.text).await {
        Ok(embedding) => {
            let dimension = embedding.len();
            (
                StatusCode::OK,
                Json(EmbeddingResponse { embedding, dimension }),
            ).into_response()
        }
        Err(e) => {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: format!("Embedding generation failed: {}", e),
                }),
            ).into_response()
        }
    }
}
Step 8: Set Up Your Main Application
In src/main.rs:
use axum::{Router, routing::post};
use candle_core::{Device, Tensor};
use std::collections::HashMap;
use std::sync::Arc;
use tokenizers::Tokenizer;

// Bring in the types, Embedder, and handlers defined above; adjust the module
// path to match your file layout (e.g. src/embeddings.rs from Step 2).
mod embeddings;
use embeddings::*;

#[derive(Clone)]
pub struct AppState {
    embedder: Arc<Embedder>,
    tensors: HashMap<String, Tensor>,
    tokenizer: Arc<Tokenizer>,
    device: Device,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Load models
    let embedder = Embedder::new(
        "models/minilm/model.onnx",
        "models/minilm/tokenizer.json",
    )?;
    let (tensors, tokenizer, device) = load_model().await?;

    let state = AppState {
        embedder: Arc::new(embedder),
        tensors,
        tokenizer: Arc::new(tokenizer),
        device,
    };

    // Create the router
    let app = Router::new()
        .route("/embed-mini", post(embed_mini))
        .route("/generate-embedding", post(generate_embedding))
        .with_state(state);

    // Start the server
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
    println!("🚀 Server running on http://localhost:3000");
    axum::serve(listener, app).await?;
    Ok(())
}
Step 9: Download Your Models
Create a models directory and download the model files:
MiniLM (ONNX):
mkdir -p models/minilm
# Download from Hugging Face: sentence-transformers/all-MiniLM-L6-v2
EmbeddingGemma:
mkdir -p models/embeddgemma
# Download from Hugging Face: google/embeddinggemma
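If you prefer to fetch the files from Rust rather than by hand, the hf-hub crate (an extra dependency, not in our Cargo.toml above) can download them into its local cache. The file names below are assumptions, so check each repository's layout, and note that gated models like EmbeddingGemma may require a Hugging Face token:
use hf_hub::api::sync::Api;

fn download_minilm() -> anyhow::Result<()> {
    let api = Api::new()?;
    let repo = api.model("sentence-transformers/all-MiniLM-L6-v2".to_string());
    // File names are assumptions; check the repo for the actual ONNX export path
    let model = repo.get("onnx/model.onnx")?;
    let tokenizer = repo.get("tokenizer.json")?;
    println!("model: {:?}, tokenizer: {:?}", model, tokenizer);
    Ok(())
}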
Step 10: Test Your API
Start the server:
cargo run --release
Test with curl:
# MiniLM endpoint
curl -X POST http://localhost:3000/embed-mini \
-H "Content-Type: application/json" \
-d '{"text": "Hello, world!"}'
# EmbedGemma endpoint
curl -X POST http://localhost:3000/generate-embedding \
-H "Content-Type: application/json" \
-d '{"text": "Rust is amazing!"}'
Performance Tips
- Use release mode for production: cargo build --release
- Adjust the thread count in the ONNX session builder based on your CPU
- Add caching for frequently requested embeddings (see the sketch after this list)
- Consider GPU support for larger models using the CUDA execution provider
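For the caching tip, a minimal in-memory sketch could look like the following; the EmbeddingCache type is hypothetical (not part of the code above), and for production you would likely bound its size, e.g. with the lru crate:
use std::collections::HashMap;
use std::sync::Mutex;

#[derive(Default)]
pub struct EmbeddingCache {
    entries: Mutex<HashMap<String, Vec<f32>>>,
}

impl EmbeddingCache {
    // Return the cached embedding for `text`, or compute and store it.
    pub fn get_or_compute<F>(&self, text: &str, compute: F) -> anyhow::Result<Vec<f32>>
    where
        F: FnOnce() -> anyhow::Result<Vec<f32>>,
    {
        if let Some(hit) = self.entries.lock().unwrap().get(text) {
            return Ok(hit.clone());
        }
        let embedding = compute()?;
        self.entries
            .lock()
            .unwrap()
            .insert(text.to_string(), embedding.clone());
        Ok(embedding)
    }
}

// Example use inside a handler:
// let embedding = cache.get_or_compute(&request.text, || embedder.embed(request.text.clone()))?;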
Next Steps
- Add batch processing support
- Implement model caching strategies
- Add metrics and monitoring
- Deploy with Docker
- Add authentication
Conclusion
You've built a production-ready embedding API in Rust! This setup gives you:
- Fast inference with ONNX Runtime
- Flexible model support (MiniLM, EmbeddingGemma)
- Type-safe request handling
- Easy integration with downstream applications
The normalized embeddings are ready for semantic search, clustering, or any similarity-based tasks.
Connect with me on LinkedIn
Questions? Drop them in the comments below!