Dheeraj Gopinath

RAG using LLMSmith and FastAPI

What is LLMSmith?

LLMSmith is a lightweight Python library designed for developing functionality powered by Large Language Models (LLMs). It allows developers to integrate generative AI capabilities into all sorts of applications. In this post, we will create a RAG-based chatbot using LLMSmith and expose it as an API endpoint in a FastAPI app.

LLMSmith repo: https://github.com/dheerajgopi/llmsmith

FYI, I’m the author of this library :)

Now, let’s get to the interesting part.

Implement the RAG functionality using LLMSmith

Here is what we will be doing:

  • Pre-process the user’s query using an OpenAI LLM to strip out information that is irrelevant for retrieval.
  • Retrieve relevant documents from the Qdrant vector DB.
  • Rerank the retrieved documents so that they are ordered by semantic relevance.
  • Generate the answer using an OpenAI LLM, with the reranked documents passed as context in the prompt.

The piece of code below uses the LLMSmith library to implement the RAG flow described above.

from textwrap import dedent
import cohere
from fastembed import TextEmbedding
import openai
from qdrant_client import AsyncQdrantClient
from llmsmith.task.retrieval.vector.qdrant import QdrantRetriever
from llmsmith.reranker.cohere import CohereReranker
from llmsmith.task.textgen.openai import OpenAITextGenTask, OpenAITextGenOptions
from llmsmith.job.job import SequentialJob

from rag_llmsmith_fastapi.config import settings


preprocess_prompt = (
    dedent("""
    Convert the natural language query from a user into a query for a vectorstore.
    In this process, you strip out information that is not relevant for the retrieval task.
    Return only the query converted for retrieval and nothing else.
    Here is the user query: {{root}}""")
    .strip("\n")
    .replace("\n", " ")
)


class RAGService:
    def __init__(
        self,
        llm_client: openai.AsyncOpenAI,
        vectordb_client: AsyncQdrantClient,
        reranker_client: cohere.AsyncClient,
        embedder: TextEmbedding,
        **_,
    ) -> None:
        self.llm_client = llm_client
        self.vectordb_client = vectordb_client
        self.reranker_client = reranker_client
        self.embedder = embedder

    async def chat(self, user_prompt):
        # Create Cohere reranker
        reranker = CohereReranker(client=self.reranker_client)

        # Embedding function to be passed into the Qdrant retriever
        def embedding_func(x):
            return list(self.embedder.query_embed(x))

        # Define the Qdrant retriever task. The embedding function and reranker are passed as parameters.
        retrieval_task = QdrantRetriever(
            name="qdrant-retriever",
            client=self.vectordb_client,
            collection_name=settings.QDRANT.COLLECTION_NAME,
            embedding_func=embedding_func,
            embedded_field_name="description",  # name of the field in the document on which embeddings are created while uploading data to the Qdrant collection
            reranker=reranker,
        )

        # Define the OpenAI LLM task for rephrasing the query
        preprocess_task = OpenAITextGenTask(
            name="openai-preprocessor",
            llm=self.llm_client,
            llm_options=OpenAITextGenOptions(model="gpt-4-turbo", temperature=0),
        )

        # Define the OpenAI LLM task for answering the query
        answer_generate_task = OpenAITextGenTask(
            name="openai-answer-generator",
            llm=self.llm_client,
            llm_options=OpenAITextGenOptions(model="gpt-4-turbo", temperature=0),
        )

        # define the sequence of tasks
        # {{root}} is a special placeholder in `input_template` which will be replaced with the prompt entered by the user (`user_prompt`).
        # The placeholder {{qdrant-retriever.output}} will be replaced with the output from Qdrant DB retriever task.
        # The placeholder {{openai-preprocessor.output}} will be replaced with the output from the query preprocessing task done by OpenAI LLM.
        job: SequentialJob[str, str] = (
            SequentialJob()
            .add_task(
                preprocess_task,
                input_template=preprocess_prompt,
            )
            .add_task(retrieval_task, input_template="{{openai-preprocessor.output}}")
            .add_task(
                answer_generate_task,
                input_template="Answer the question based on the context: \n\n QUESTION:\n{{root}}\n\nCONTEXT:\n{{qdrant-retriever.output}}",
            )
        )

        # Now, run the job
        await job.run(user_prompt)

        return job.task_output("openai-answer-generator")

Using LLMSmith makes it quite easy to implement such LLM-based functionality. In the code above, we first create the following tasks:

  • QdrantRetriever — To retrieve documents from Qdrant.
  • CohereReranker — Rerank documents based on semantic relevance (passed into QdrantRetriever).
  • OpenAITextGenTask — To execute LLM calls. Used for both pre-processing and answer generation in our case.

After the task definitions, all these tasks are run sequentially using SequentialJob. In an LLMSmith job, the output of a previous task can be easily passed into the next task via placeholders using the input_template parameter, as the sketch below shows.
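To make the placeholder mechanism more concrete, here is a minimal, self-contained sketch of a two-step SequentialJob. The task names (summarizer, translator) and prompts are purely illustrative; the SequentialJob, add_task, run and task_output calls are the same ones used in the RAG code above.

import asyncio

import openai

from llmsmith.task.textgen.openai import OpenAITextGenTask, OpenAITextGenOptions
from llmsmith.job.job import SequentialJob


async def main():
    # Assumes OPENAI_API_KEY is set in the environment
    llm_client = openai.AsyncOpenAI()

    # Two illustrative tasks: the first summarizes, the second translates the summary
    summarize_task = OpenAITextGenTask(
        name="summarizer",
        llm=llm_client,
        llm_options=OpenAITextGenOptions(model="gpt-4-turbo", temperature=0),
    )
    translate_task = OpenAITextGenTask(
        name="translator",
        llm=llm_client,
        llm_options=OpenAITextGenOptions(model="gpt-4-turbo", temperature=0),
    )

    job: SequentialJob[str, str] = (
        SequentialJob()
        # {{root}} is replaced with the prompt passed to job.run(...)
        .add_task(summarize_task, input_template="Summarize this text: {{root}}")
        # {{summarizer.output}} is replaced with the output of the first task
        .add_task(
            translate_task,
            input_template="Translate this summary into French: {{summarizer.output}}",
        )
    )

    await job.run("<some long text>")
    print(job.task_output("translator"))


asyncio.run(main())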

Integrate the RAGService into a FastAPI endpoint

This is easy! Just call the chat method on the RAGService instance in your route handler function.

from fastapi import APIRouter

from llmsmith.task.models import TaskOutput

from rag_llmsmith_fastapi.chat.model import ChatRequest, ChatResponse
from rag_llmsmith_fastapi.chat.service import RAGService


class ChatController:
    def __init__(self, rag_svc: RAGService) -> None:
        self.rag_svc = rag_svc
        self.router: APIRouter = APIRouter(tags=["Chat endpoint"], prefix="/api")

        self.router.add_api_route(
            path="/chat",
            endpoint=self.chat,
            methods=["POST"],
        )

    async def chat(self, req_body: ChatRequest):
        rag_response: TaskOutput = await self.rag_svc.chat(req_body.content)
        return ChatResponse(content=rag_response.content)
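The ChatRequest and ChatResponse models imported from rag_llmsmith_fastapi.chat.model are simple request/response schemas. They are not shown in this post, but based on how they are used above, a minimal sketch could look like this (the actual models in the repo may differ):

from pydantic import BaseModel


class ChatRequest(BaseModel):
    content: str  # the user's question


class ChatResponse(BaseModel):
    content: str  # the generated answer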

The complete code can be found in this repo.
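For context, here is a rough sketch of how the clients, RAGService and ChatController could be wired together in the FastAPI app. The settings fields (other than QDRANT.COLLECTION_NAME) and the rag_llmsmith_fastapi.chat.controller module path are assumptions for illustration; check the repo for the actual wiring.

import cohere
import openai
from fastapi import FastAPI
from fastembed import TextEmbedding
from qdrant_client import AsyncQdrantClient

from rag_llmsmith_fastapi.chat.controller import ChatController  # assumed module path
from rag_llmsmith_fastapi.chat.service import RAGService
from rag_llmsmith_fastapi.config import settings

# External clients (the settings field names below are assumptions)
llm_client = openai.AsyncOpenAI(api_key=settings.OPENAI.API_KEY)
vectordb_client = AsyncQdrantClient(url=settings.QDRANT.URL)
reranker_client = cohere.AsyncClient(api_key=settings.COHERE.API_KEY)
embedder = TextEmbedding()  # uses fastembed's default embedding model

# Wire the service and controller, then register the /api/chat route
rag_svc = RAGService(
    llm_client=llm_client,
    vectordb_client=vectordb_client,
    reranker_client=reranker_client,
    embedder=embedder,
)
chat_controller = ChatController(rag_svc=rag_svc)

app = FastAPI()
app.include_router(chat_controller.router)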

To run the application, clone the repository from GitHub and follow the README.md file.

Contributors Welcome

All contributions to LLMSmith, no matter how small, are always welcome. To see how you can help and where to start, see CONTRIBUTING.md.
