Originally published at shuttle.rs

Implementing Semantic Caching using Qdrant and Rust

Hello world! Today we're going to learn about semantic caching with Qdrant, in Rust. By the end of this tutorial, you'll have a Rust application that can do the following:

  • Ingest a CSV file, turn it into embeddings with the help of an embedding model, and insert them into Qdrant
  • Create two collections in Qdrant - one for regular usage and one for caching
  • Utilize semantic caching for quicker access

Interested in deploying or got lost and want to find a repository with the code? You can find that here

What is semantic caching, and why use it?

In a regular data cache, we store information to enable faster retrieval later on. For example, you might have a web service that's served behind Nginx; you can have Nginx cache either all responses or only the most-accessed endpoints. This improves performance and reduces load on your web server.

Semantic caching in this regard is quite similar. Using vector databases, we can create database collections that store the queries themselves. For example, these two questions semantically carry the same meaning:

  • What are some best practices for writing the Rust programming language?
  • What are some best practices for writing Rustlang?

We can store a copy of the query in a cache collection, with the answer as a JSON payload. If users then ask a similar question, we can retrieve the embedding and fetch the answer from the payload. This avoids us having to use an LLM to get our answer.

There are a couple of benefits to semantic caching:

  • Prompts that require long responses can see serious cost savings.
  • It's pretty easy to implement and fairly cheap - the only costs are storage and running the embedding model.
  • You can use a cheaper model for the cache than your regular one (more on this at the end of the article).

Semantic caching is normally used with RAG (Retrieval-Augmented Generation). RAG is a framework for retrieving context from pre-embedded materials: CSV files or documents are turned into embeddings using a model and stored in a database, and whenever a user wants to find documents similar to a given prompt, the prompt is embedded and used to search that database.

Of course, there are good reasons not to use semantic caching. Prompts that need differing, varied answers won't find any use for semantic caching. This is particularly relevant in generative AI usage. Fetching a stored query will reduce the creativity of the response. Regardless, if part of your pipeline is able to capitalise on semantic caching, it's a good idea to do so.

Project setup

Getting started

To get started, initialise a new project with cargo shuttle init and pick the Axum framework when prompted. We'll then install our dependencies using the shell snippet below:

cargo add qdrant-client@1.7.0 anyhow async-openai serde serde_json \
shuttle-qdrant uuid -F uuid/v4,serde/derive

You can find our quickstart docs here.

Setting up secrets

To set up our secrets, we'll use a Secrets.toml file located in our project root (you will need to create this manually). You can then add whatever secrets you need using the format below:

OPENAI_API_KEY = ""
QDRANT_URL = ""
QDRANT_API_KEY = ""

Setting up Qdrant

Creating collections

Now that we're set up, we'll add some general methods for creating a regular collection and a cache collection, to simulate a real-world scenario, plus a new() function for constructing the RagSystem struct itself. We'll create the struct first. Note that although we're using vectors with 1536 dimensions here, the number of dimensions you need depends on the embedding model you're using.

use std::env;

use async_openai::{config::OpenAIConfig, Client};
use qdrant_client::prelude::QdrantClient;

struct RagSystem {
    qdrant_client: QdrantClient,
    openai_client: Client<OpenAIConfig>
}

static REGULAR_COLLECTION_NAME: &str = "my-collection";
static CACHE_COLLECTION_NAME: &str = "my-collection-cached";

impl RagSystem {
    fn new(qdrant_client: QdrantClient) -> Self {
        let openai_api_key = env::var("OPENAI_API_KEY").unwrap();

        let openai_config = OpenAIConfig::new()
            .with_api_key(openai_api_key)
            // replace this with your own OpenAI organization ID, or drop the
            // call entirely if you don't use organizations
            .with_org_id("qdrant-shuttle-semantic-cache");

        let openai_client = Client::with_config(openai_config);

        Self {
            openai_client,
            qdrant_client,
        }
    }
}

Now we'll create the method for initialising our regular collection. Note that we'll only need to call it once: if we try to initialise a collection that has already been created, we'll get an error (there's a small guard sketch for this after the code below).

use anyhow::Result;
use qdrant_client::prelude::CreateCollection;
use qdrant_client::qdrant::{vectors_config::Config, Distance, VectorParams, VectorsConfig};

impl RagSystem {
    async fn create_regular_collection(&self) -> Result<()> {
        self.qdrant_client
            .create_collection(&CreateCollection {
                collection_name: REGULAR_COLLECTION_NAME.to_string(),
                vectors_config: Some(VectorsConfig {
                    config: Some(Config::Params(VectorParams {
                        size: 1536,
                        distance: Distance::Cosine.into(),
                        ..Default::default()
                    })),
                }),
                ..Default::default()
            })
            .await?;

        Ok(())
    }
}
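If you'd rather not rely on hitting that error, you can guard collection creation with an existence check. Here's a minimal sketch, assuming your qdrant-client version exposes has_collection:

// Optional guard: only create the collection if it doesn't exist yet.
if !self.qdrant_client.has_collection(REGULAR_COLLECTION_NAME).await? {
    self.create_regular_collection().await?;
}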

Next, we’ll create our cache collection. When creating this collection, note that we use Distance::Euclid instead of Distance::Cosine. Both of these can be defined as follows:

  • Distance::Cosine (or “cosine similarity”) measures how closely two vectors point in the same direction. If we plot two vectors on a graph, for example, a vector located at [2, 1] would be much closer to [1, 1] than it would be to [-1, -2]. Cosine similarity is overwhelmingly used for measuring document similarity in text analysis.
  • Distance::Euclid (or “Euclidean distance”) measures how far apart two vectors are - i.e., the distance from A to B, where A and B are two points on a graph. Rather than measuring similarity of direction, here we want to determine whether something is mostly or exactly the same (the quick sketch below makes the difference concrete).
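
Here's a tiny, self-contained example that computes both metrics for the vectors mentioned above:

fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm(a) * norm(b))
}

fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum::<f32>().sqrt()
}

fn main() {
    // [2, 1] points in almost the same direction as [1, 1]...
    println!("{:.2}", cosine_similarity(&[2.0, 1.0], &[1.0, 1.0]));   // ~0.95
    // ...and in roughly the opposite direction to [-1, -2]
    println!("{:.2}", cosine_similarity(&[2.0, 1.0], &[-1.0, -2.0])); // -0.80
    // Euclidean distance instead measures how far apart the points are
    println!("{:.2}", euclidean_distance(&[2.0, 1.0], &[1.0, 1.0]));   // 1.00
    println!("{:.2}", euclidean_distance(&[2.0, 1.0], &[-1.0, -2.0])); // ~4.24
}

For a cache, Euclidean distance suits us well: we only want a hit when the stored question is essentially the same as the incoming one. With that in mind, here's the cache collection: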
impl RagSystem {
    async fn create_cache_collection(&self) -> Result<()> {
        self.qdrant_client
            .create_collection(&CreateCollection {
                collection_name: CACHE_COLLECTION_NAME.to_string(),
                vectors_config: Some(VectorsConfig {
                    config: Some(Config::Params(VectorParams {
                        size: 1536,
                        distance: Distance::Euclid.into(),
                        ..Default::default()
                    })),
                }),
                ..Default::default()
            })
            .await?;

        Ok(())
    }
}

Creating embeddings

Next, we need to create embeddings from a file input - using a CSV file as an example. To do so, we'll need to do the following:

  • Read the file and parse it into a string (std::fs::read_to_string() gives us a String directly)
  • Chunk the file contents into appropriately sized pieces (here we'll do it naively per row, for illustration)
  • Bulk-embed the chunks and add the resulting embeddings to Qdrant

Here we're using the async-openai library to create the embeddings - but if you don't want to use OpenAI, you can always use fastembed-rs as an alternative, or any other crate of your choice that supports embedding creation.

use std::path::PathBuf;

use anyhow::Result;
use async_openai::types::{CreateEmbeddingRequest, EmbeddingInput};
use async_openai::Embeddings;

impl RagSystem {
    async fn embed_and_upsert_csv_file(&self, file_path: PathBuf) -> Result<()> {
        let file_contents = std::fs::read_to_string(&file_path)?;

        // note here that we skip 1 because CSV files typically have headers
        // if you don't have any headers, you can remove it
        let chunked_file_contents: Vec<String> =
            file_contents.lines().skip(1).map(|x| x.to_owned()).collect();

        let embedding_request = CreateEmbeddingRequest {
            model: "text-embedding-ada-002".to_string(),
            input: EmbeddingInput::StringArray(chunked_file_contents.to_owned()),
            encoding_format: None, // defaults to f32
            user: None,
            // text-embedding-ada-002 always returns 1536-dimensional vectors and
            // doesn't accept an explicit dimensions parameter, so we leave this unset
            dimensions: None,
        };

        let embeddings = Embeddings::new(&self.openai_client)
            .create(embedding_request)
            .await?;

        if embeddings.data.is_empty() {
            return Err(anyhow::anyhow!(
                "There were no embeddings returned by OpenAI!"
            ));
        }

        let embeddings_vec: Vec<Vec<f32>> =
            embeddings.data.into_iter().map(|x| x.embedding).collect();

        // note that we create the upsert_embedding function later on
        for embedding in embeddings_vec {
            self.upsert_embedding(embedding, file_contents.clone())
                .await?;
        }

        Ok(())
    }
}


We'll then need to embed any further inputs to search for any matching embeddings. The embed_prompt function will look quite similar to the embedding part of our embed_and_upsert_csv_file function. However, it will instead return a Vec<f32> as we'll want to use this later to search our collection.

impl RagSystem {
    pub async fn embed_prompt(&self, prompt: &str) -> Result<Vec<f32>> {
        let embedding_request = CreateEmbeddingRequest {
            model: "text-embedding-ada-002".to_string(),
            input: EmbeddingInput::String(prompt.to_owned()),
            encoding_format: None, // defaults to f32
            user: None,
            // as above: ada-002 doesn't accept an explicit dimensions parameter
            dimensions: None,
        };

        let embeddings = Embeddings::new(&self.openai_client)
            .create(embedding_request)
            .await?;

        if embeddings.data.is_empty() {
            return Err(anyhow::anyhow!(
                "There were no embeddings returned by OpenAI!"
            ));
        }

        Ok(embeddings.data.into_iter().next().unwrap().embedding)
    }
}


Upserting embeddings

Once we've created our embeddings, we'll add a method for upserting them into Qdrant called upsert_embedding. This will deal with creating the payload for each embedding and inserting it into the database. Once a point is in the collection, we can search the collection later on and get the associated JSON payload back alongside the embedding!

The function will look like this:

use qdrant_client::prelude::PointStruct;

impl RagSystem {
    async fn upsert_embedding(&self, embedding: Vec<f32>, file_contents: String) -> Result<()> {
        let payload = serde_json::json!({
           "document": file_contents
        })
        .try_into()
        .map_err(|x| anyhow::anyhow!("Ran into an error when converting the payload: {x}"))?;

        let points = vec![PointStruct::new(
            uuid::Uuid::new_v4().to_string(),
            embedding,
            payload,
        )];

        self.qdrant_client
            .upsert_points(REGULAR_COLLECTION_NAME.to_owned(), None, points, None)
            .await?;

        Ok(())
    }
}


Here, we use a uuid::Uuid as a unique identifier for our embeddings. You could also use a u64 counter that increases with every embedding, but be careful not to accidentally overwrite your own data: inserting a new point with the same ID as an existing point in the collection will overwrite it.
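
If you'd rather use numeric IDs, PointStruct::new also accepts a u64, since point IDs can be either integers or UUID strings. A quick sketch (the counter value here is hypothetical):

// Hypothetical counter - in a real service this would need to be persisted or
// derived from the data, otherwise you'll silently overwrite points
let next_id: u64 = 42;
let point = PointStruct::new(next_id, embedding, payload);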

Of course, we also need to create a method for adding things to our cache. Note that our payload is different here. Instead of using the document payload field, we use answer since the payload will hold a pre-generated answer to the question.

impl RagSystem {
    pub async fn add_to_cache(&self, embedding: Vec<f32>, answer: String) -> Result<()> {
        let payload = serde_json::json!({
           "answer": answer
        })
        .try_into()
        .map_err(|x| anyhow::anyhow!("Ran into an error when converting the payload: {x}"))?;

        let points = vec![PointStruct::new(
            uuid::Uuid::new_v4().to_string(),
            embedding,
            payload,
        )];

        self.qdrant_client
            .upsert_points(CACHE_COLLECTION_NAME.to_owned(), None, points, None)
            .await?;

        Ok(())
    }
}

Searching Qdrant collections

Having made something we can search against in Qdrant, we'll need to implement some search methods on our RagSystem. We'll split this up into two methods:

  • search - for the regular collection
  • search_cache - for the cache collection

When searching for an embedding, we should first check our semantic cache with search_cache. If it doesn't find anything, we fall back to search to fetch the document, prompt OpenAI with it, and then return the result as the response (adding the new question and answer to the cache along the way).

To make our methods a little more error-resistant, we use .into_iter().next() on the results, which simply takes the first item from the results vector. Since we set the search limit to 1, we only ever expect a single embedding back, but you can increase or decrease the limit as you like.

Once we find a match, we need to get the document key from our JSON payload associated with the embedding match and return it. We'll be using this as context in our RAG prompt later on!

use qdrant_client::qdrant::{
    with_payload_selector::SelectorOptions, SearchPoints, WithPayloadSelector
};

impl RagSystem {
    pub async fn search(&self, embedding: Vec<f32>) -> Result<String> {
        let payload_selector = WithPayloadSelector {
            selector_options: Some(SelectorOptions::Enable(true)),
        };

        let search_points = SearchPoints {
            collection_name: REGULAR_COLLECTION_NAME.to_owned(),
            vector: embedding,
            limit: 1,
            with_payload: Some(payload_selector),
            ..Default::default()
        };

        let search_result = self
            .qdrant_client
            .search_points(&search_points)
            .await
            .inspect_err(|x| println!("An error occurred while searching for points: {x}"))?;

        let result = search_result.result.into_iter().next();

        let Some(result) = result else {
            return Err(anyhow::anyhow!("There's nothing matching."));
        };

        Ok(result.payload.get("document").unwrap().to_string())
    }
}


Of course, you'll also want to implement a function for searching your cache collection. The two functions are mostly the same, except that here we read the answer field from the payload instead of document.

impl RagSystem {
    pub async fn search_cache(&self, embedding: Vec<f32>) -> Result<String> {
        let payload_selector = WithPayloadSelector {
            selector_options: Some(SelectorOptions::Enable(true)),
        };

        let search_points = SearchPoints {
            collection_name: CACHE_COLLECTION_NAME.to_owned(),
            vector: embedding,
            limit: 1,
            with_payload: Some(payload_selector),
            ..Default::default()
        };

        let search_result = self
            .qdrant_client
            .search_points(&search_points)
            .await
            .inspect_err(|x| println!("An error occurred while searching for points: {x}"))?;

        let result = search_result.result.into_iter().next();

        let Some(result) = result else {
            return Err(anyhow::anyhow!("There's nothing matching."));
        };

        Ok(result.payload.get("answer").unwrap().to_string())
    }
}

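One refinement worth considering (not shown above): as written, search_cache returns the nearest cached point no matter how far away it is, so after the first question almost every request will look like a cache hit. SearchPoints has a score_threshold field you can set so that Qdrant only returns sufficiently close matches - the value below is an arbitrary placeholder you'd want to tune against your own data and embedding model:

let search_points = SearchPoints {
    collection_name: CACHE_COLLECTION_NAME.to_owned(),
    vector: embedding,
    limit: 1,
    with_payload: Some(payload_selector),
    // with Euclid distance, lower scores mean closer matches; results scoring
    // worse than the threshold are filtered out by Qdrant
    score_threshold: Some(0.5),
    ..Default::default()
};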

Prompting

Of course, now that everything else is done, the last thing to do is prompting! Below, we generate a prompt that consists of the user's question plus the provided context, then grab the first result from OpenAI and return the message content.

use async_openai::types::{
    ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
    CreateChatCompletionRequestArgs,
};

impl RagSystem {
    pub async fn prompt(&self, prompt: &str, context: &str) -> Result<String> {
        let input = format!(
            "{prompt}

            Provided context:
            {context}
            "
        );

        let res = self
            .openai_client
            .chat()
            .create(
                CreateChatCompletionRequestArgs::default()
                    .model("gpt-4o")
                    .messages(vec![
                        ChatCompletionRequestMessage::User(
                            ChatCompletionRequestUserMessageArgs::default()
                                .content(input)
                                .build()?,
                        ),
                    ])
                    .build()?,
            )
            .await
            .map(|res| {
                // We extract the first result
                match res.choices[0].message.content.clone() {
                    Some(res) => Ok(res),
                    None => Err(anyhow::anyhow!("There was no result from OpenAI")),
                }
            })??;

        println!("Retrieved result from prompt: {res}");

        Ok(res)
    }
}


Using Qdrant in a Rust web service

Let's have a quick look at a real-world example. Below is an HTTP endpoint for the Axum framework that takes our RagSystem as application state. It embeds the prompt and attempts to search the cache. If there's no result, it searches the regular collection for a match, the resulting document payload is added to an augmented prompt, and the question and answer are added to the cache. Finally, a response is returned from the endpoint.

use axum::{Json, extract::State, response::IntoResponse, http::StatusCode};
use serde::Deserialize;

#[derive(Deserialize)]
struct Prompt {
    prompt: String,
}

async fn prompt(
    State(state): State<RagSystem>,
    Json(prompt): Json<Prompt>,
) -> Result<impl IntoResponse, impl IntoResponse> {
    let embedding = match state.embed_prompt(&prompt.prompt).await {
        Ok(embedding) => embedding,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("An error occurred while embedding the prompt: {e}"),
            ))
        }
    };

    if let Ok(answer) = state.search_cache(embedding.clone()).await {
        return Ok(answer);
    }

    let search_result = match state.search(embedding.clone()).await {
        Ok(res) => res,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("An error occurred while prompting: {e}"),
            ))
        }
    };

    let llm_response = match state.prompt(&prompt.prompt, &search_result).await {
        Ok(prompt_result) => prompt_result,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Something went wrong while prompting: {e}"),
            ))
        }
    };

    // add_to_cache takes ownership of the answer, so clone it - we still need
    // to return llm_response below
    if let Err(e) = state.add_to_cache(embedding, llm_response.clone()).await {
        return Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Something went wrong while adding item to the cache: {e}"),
        ));
    };

    Ok(llm_response)
}


The last thing to do is set up our main function. Note that we add the shuttle_qdrant::Qdrant annotation to our main function, which automatically provisions a local Qdrant instance with Docker when running locally. In production, though, we'll need the cloud_url and api_key fields filled out.

use axum::{routing::post, Router};
use shuttle_runtime::SecretStore;

#[shuttle_runtime::main]
async fn main(
    #[shuttle_qdrant::Qdrant(
        cloud_url = "{secrets.QDRANT_URL}",
        api_key = "{secrets.QDRANT_API_KEY}"
    )]
    qdrant: QdrantClient,
    #[shuttle_runtime::Secrets] secrets: SecretStore,
) -> shuttle_axum::ShuttleAxum {
    secrets.into_iter().for_each(|x| env::set_var(x.0, x.1));

    let rag = RagSystem::new(qdrant);

    let setup_required = true;

    if setup_required {
        rag.create_regular_collection().await?;
        rag.create_cache_collection().await?;

        rag.embed_and_upsert_csv_file("test.csv".into()).await?;
    }

    let rtr = Router::new().route("/prompt", post(prompt)).with_state(rag);

    Ok(rtr.into())
}

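Before deploying, you can give the endpoint a quick spin locally with cargo shuttle run and a request like the one below (this assumes Shuttle's default local port of 8000):

curl -X POST http://localhost:8000/prompt \
  -H 'Content-Type: application/json' \
  -d '{"prompt": "What are some best practices for writing Rustlang?"}'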

Deploying

To deploy, all you need to do is run cargo shuttle deploy (with the --allow-dirty flag, or --ad for short, if you're on a Git branch with uncommitted changes) and wait for the deployment to finish! After your first deploy, subsequent deploys only need to re-compile your application (plus any newly added dependencies), so they'll be much, much faster.

Extending this example

Want to extend this example? Here are a couple of ways you can do that.

Use a cheaper model for semantic caching

While using a high-performance model is great and all, one thing we particularly want to save on is cost. One way to save tokens here is to use a cheaper model and ask it whether one question is semantically the same as another. Here's a prompt you can use:

Are these two questions semantically the same? Answer either 'Yes' or 'No'. Do not answer with anything else. If you don't know the answer, say 'I don't know'.

Question 1: <question 1 goes here>
Question 2: <question 2 goes here>

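Wired into our RagSystem, that check might look something like the sketch below. The model choice (gpt-4o-mini) and the is_semantically_same helper are illustrative assumptions, not part of the code above:

impl RagSystem {
    // Sketch: ask a cheaper model whether two questions mean the same thing
    pub async fn is_semantically_same(&self, q1: &str, q2: &str) -> Result<bool> {
        let input = format!(
            "Are these two questions semantically the same? Answer either 'Yes' or 'No'. \
             Do not answer with anything else. If you don't know the answer, say 'I don't know'.\n\n\
             Question 1: {q1}\nQuestion 2: {q2}"
        );

        let res = self
            .openai_client
            .chat()
            .create(
                CreateChatCompletionRequestArgs::default()
                    .model("gpt-4o-mini")
                    .messages(vec![ChatCompletionRequestMessage::User(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(input)
                            .build()?,
                    )])
                    .build()?,
            )
            .await?;

        // treat anything other than an explicit 'Yes' as a cache miss
        let answer = res
            .choices
            .first()
            .and_then(|choice| choice.message.content.clone())
            .unwrap_or_default();

        Ok(answer.trim().eq_ignore_ascii_case("yes"))
    }
}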

Smaller payloads

While our example does work, one thing you might need to think about is the size of the payload - the associated data connected to each embedding. If you insert the whole file contents as the payload for every single embedding of a large file, your storage and resource usage will grow quickly. You can mitigate this by only inserting the relevant slice of the file per embedding - in our case, the row the embedding was generated from - as sketched below.
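
For example, the loop at the end of embed_and_upsert_csv_file could pair each embedding with its source row and store just that row, instead of cloning the entire file every time:

// pair each embedding with the row it was generated from and store only
// that row as the payload, instead of the whole file contents
for (embedding, row) in embeddings_vec.into_iter().zip(chunked_file_contents) {
    self.upsert_embedding(embedding, row).await?;
}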

Finishing up

Thanks for reading! By using semantic caching, we can create a much more performant RAG system that saves on both time and costs.
