Doc Sage🧙‍♂️- Create a Smart RAG App with LangChain and Streamlit

Ever been reading through a PDF document and thought, "Hmm, if only there was a way I could quickly extract the relevant information"? That would save quite a lot of time. In the few years since large language models (LLMs) were introduced, they have revolutionized the way we interact with text data.

LLMs are trained on vast amounts of data from the internet and other text sources, making them highly effective at many general-purpose tasks. However, in certain cases, you might want to train or augment an LLM with your own data to make the responses more relevant to your needs. In this article, I’ll show you how to create an app that uses Retrieval-Augmented Generation (RAG) to answer questions specific to particular documents or web pages. But first...

What is RAG?

RAG stands for Retrieval-Augmented Generation. It is a type of natural language processing framework that combines the benefits of retrieval-based and generative models.

These models:

  1. Retrieve relevant info from a database or knowledge graph
  2. Use this information to generate a more accurate response
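
In rough pseudocode (retrieve and llm here are placeholders for the retriever and language model we build later in this article), the flow looks like this:

def rag_answer(question: str) -> str:
    # 1. Retrieve: fetch relevant text from a knowledge store
    context = retrieve(question)
    # 2. Generate: have the LLM answer using only that context
    prompt = f"Answer using this context only.\n\nContext:\n{context}\n\nQuestion: {question}"
    return llm(prompt)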

Pretty useful, especially when you're handling rare topics that the LLM may not know about. I took quite a lot of inspiration from Google's NotebookLM, but this app is not going to come anywhere close to what NotebookLM does; it's a gentle introduction to help you understand what's happening under the hood and, hopefully, appreciate RAG.

With that, let's proceed with the...

Setup

For this application, we're going to be making use of the following packages:

  • langchain - to connect us to the OpenAI large language models
  • chromadb - the vector store we're going to use to store our information
  • streamlit - for the interface, because of how easy it is to set up

Before we continue, let me give you a sneak peek at the final product:

You can find the code for this application on my GitHub.

Install the packages as follows:
pip install streamlit
pip install langchain langchain-community langchain-openai langchain-chroma

We're also going to need BeautifulSoup, so install it as follows:
pip install beautifulsoup4

The application has two databases: a vector database to store the documents and a relational database to store the chats. SQLite is simple to set up, hence we are going to be using it in this tutorial. Feel free to use any other database.

And with that, we're ready to begin.

Database design

Our relational database is going to have 3 tables:

  • The chat table - which will store the chat names
  • The sources table - which will store the different sources we have loaded into the vector store
  • The messages table - which will store the messages between the human and the AI

Create a file named create_relational_db.py and add the following code:

import sqlite3

# Connect to SQLite database (or create it if it doesn't exist)
conn = sqlite3.connect("doc_sage.sqlite")
cursor = conn.cursor()

# Create 'chat' table
cursor.execute(
    """
    CREATE TABLE IF NOT EXISTS chat (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        title TEXT NOT NULL,
        created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
        updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
    )
"""
)

# Create 'sources' table
cursor.execute(
    """
    CREATE TABLE IF NOT EXISTS sources (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        name TEXT NOT NULL,
        source_text TEXT,
        type TEXT DEFAULT 'document',
        chat_id INTEGER,
        FOREIGN KEY (chat_id) REFERENCES chat(id)
    )
"""
)


# Create 'messages' table
cursor.execute(
    """
    CREATE TABLE IF NOT EXISTS messages (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        chat_id INTEGER NOT NULL,
        sender TEXT NOT NULL,
        content TEXT NOT NULL,
        timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
        FOREIGN KEY(chat_id) REFERENCES chat(id)
    );
"""
)


# Commit the transaction
conn.commit()


# Close the connection
conn.close()

print("Tables created successfully.")

The sources table stores the name of the document or the link address. The type field indicates whether the source is a document (file) or a webpage. It can only have two values, document or link, with document as the default.

The messages table stores the messages between the user and the LLM. It is linked to the chat table by a foreign key on chat_id. The sender is either ai or user, which is how we keep track of who sent which message. The content is the actual content of the message from either the user or the AI.

Create the database and tables by running:
python create_relational_db.py

Now let us move on to the functions that are going to be operating on the database.

Create another file named db.py and add the following code:

import sqlite3

# Connect to SQLite database
def connect_db():
    return sqlite3.connect("doc_sage.sqlite")

# CRUD Operations for 'chat' table
def create_chat(title):
    conn = connect_db()
    cursor = conn.cursor()
    cursor.execute("INSERT INTO chat (title) VALUES (?)", (title,))
    chat_id = cursor.lastrowid
    conn.commit()
    conn.close()
    return chat_id


def list_chats():
    conn = connect_db()
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM chat ORDER BY created_at DESC")
    chats = cursor.fetchall()
    conn.close()
    return chats


def read_chat(chat_id):
    conn = connect_db()
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM chat WHERE id = ?", (chat_id,))
    result = cursor.fetchone()
    conn.close()
    return result


def update_chat(chat_id, new_title):
    conn = connect_db()
    cursor = conn.cursor()
    cursor.execute(
        "UPDATE chat SET title = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?",
        (new_title, chat_id),
    )
    conn.commit()
    conn.close()


def delete_chat(chat_id):
    conn = connect_db()
    cursor = conn.cursor()
    cursor.execute("DELETE FROM chat WHERE id = ?", (chat_id,))
    conn.commit()
    conn.close()


def create_source(name, source_text, chat_id, source_type="document"):
    conn = connect_db()
    cursor = conn.cursor()
    cursor.execute(
        "INSERT INTO sources (name, source_text, chat_id, type) VALUES (?, ?, ?, ?)",
        (name, source_text, chat_id, source_type),
    )
    conn.commit()
    conn.close()


def read_source(source_id):
    conn = connect_db()
    cursor = conn.cursor()
    cursor.execute("SELECT * FROM sources WHERE id = ?", (source_id,))
    result = cursor.fetchone()
    conn.close()
    return result

def update_source(source_id, new_name, new_source_text):
    conn = connect_db()
    cursor = conn.cursor()
    cursor.execute(
        "UPDATE sources SET name = ?, source_text = ? WHERE id = ?",
        (new_name, new_source_text, source_id),
    )
    conn.commit()
    conn.close()


def list_sources(chat_id, source_type=None):
    conn = connect_db()
    cursor = conn.cursor()
    if source_type:
        cursor.execute(
            "SELECT * FROM sources WHERE chat_id = ? AND type = ?",
            (chat_id, source_type),
        )
    else:
        cursor.execute("SELECT * FROM sources WHERE chat_id = ?", (chat_id,))
    sources = cursor.fetchall()
    conn.close()
    return sources


def delete_source(source_id):
    conn = connect_db()
    cursor = conn.cursor()
    cursor.execute("DELETE FROM sources WHERE id = ?", (source_id,))
    conn.commit()
    conn.close()


# CRUD Operations for 'messages' table
def create_message(chat_id, sender, content):
    conn = connect_db()
    cursor = conn.cursor()
    cursor.execute(
        "INSERT INTO messages (chat_id, sender, content) VALUES (?, ?, ?)",
        (chat_id, sender, content),
    )
    conn.commit()
    conn.close()


def get_messages(chat_id):
    conn = connect_db()
    cursor = conn.cursor()
    cursor.execute(
        "SELECT sender, content FROM messages WHERE chat_id = ? ORDER BY timestamp ASC",
        (chat_id,),
    )
    messages = cursor.fetchall()
    conn.close()
    return messages


def delete_messages(chat_id):
    conn = connect_db()
    cursor = conn.cursor()
    cursor.execute("DELETE FROM messages WHERE chat_id = ?", (chat_id,))
    conn.commit()
    conn.close()

The above code is responsible for all the operations on the sqlite3 database.
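
As a quick sanity check, you could exercise these helpers from a Python shell (the titles and messages here are just illustrative):

from db import create_chat, create_message, get_messages, list_chats

chat_id = create_chat("My first chat")           # returns the new chat's id
create_message(chat_id, "user", "Hello there!")
create_message(chat_id, "ai", "Hi! How can I help?")

print(get_messages(chat_id))   # [('user', 'Hello there!'), ('ai', 'Hi! How can I help?')]
print(list_chats())            # most recently created chats first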

RAG Functions

In this section, we are going to create the functions that operate on the vector store, from loading documents to retrieving them and generating the responses. This is where most of the magic happens. Get some coffee because this section is rather long and a little more complicated🙂. But worry not, I will explain as simply as possible.

Create a file named vector_functions.py and add the following lines:

import os
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_text_splitters import CharacterTextSplitter
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_community.document_loaders import (
    TextLoader,
    CSVLoader,
    PyPDFLoader,
    Docx2txtLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
)

import environ
env = environ.Env()

# reading .env file
environ.Env.read_env()

llm = ChatOpenAI(
    model="gpt-4o-mini",
    api_key=env("OPENAI_API_KEY"),
)

embeddings = OpenAIEmbeddings(
    api_key=env("OPENAI_API_KEY"),
)

text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)

In this example, I am making use of the python-environ module to handle the API key, so make sure you have it installed with:

pip install python-environ

Create a .env file and add your OpenAI API key, e.g.:
OPENAI_API_KEY=Your-API-Key

We initialize the large language model with the ChatOpenAI class. In this example, we are going to be making use of gpt-4o-mini. You can use other models, like Anthropic's Claude; make sure to refer to the documentation for more information.

OpenAIEmbeddings is a class that generates embeddings — a way to convert text into a numeric format so that AI models can process it more easily.

CharacterTextSplitter breaks up long text into smaller chunks to make it easier for the AI model to handle. chunk_size=1000 means each text chunk will contain up to 1000 characters. chunk_overlap=0 means there’s no overlap between chunks; each chunk is independent and contains a unique section of the text. This approach is often useful when working with long articles, documents, or books that need to be processed in parts.
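
To get a feel for these two pieces, here is a small illustrative snippet (the file name is a placeholder, and the exact chunk boundaries depend on the separators CharacterTextSplitter finds in the text):

# Split a long string into chunks of roughly 1000 characters
with open("example.txt") as f:  # hypothetical file
    chunks = text_splitter.split_text(f.read())
print(len(chunks), "chunks")

# Embed a single piece of text -- the result is a plain list of floats
vector = embeddings.embed_query("What is RAG?")
print(len(vector))  # e.g. 1536 dimensions for OpenAI's default embedding model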

Next, we create a function that is going to be responsible for loading different document types:

def load_document(file_path: str) -> list[Document]:
    """
    Load a document from a file path.
    Supports .txt, .pdf, .docx, .csv, .html, and .md files.

    Args:
    file_path (str): Path to the document file.

    Returns:
    list[Document]: A list of Document objects.

    Raises:
    ValueError: If the file type is not supported.
    """
    _, file_extension = os.path.splitext(file_path)

    if file_extension == ".txt":
        loader = TextLoader(file_path)
    elif file_extension == ".pdf":
        loader = PyPDFLoader(file_path)
    elif file_extension == ".docx":
        loader = Docx2txtLoader(file_path)
    elif file_extension == ".csv":
        loader = CSVLoader(file_path)
    elif file_extension == ".html":
        loader = UnstructuredHTMLLoader(file_path)
    elif file_extension == ".md":
        loader = UnstructuredMarkdownLoader(file_path)
    else:
        raise ValueError(f"Unsupported file type: {file_extension}")

    return loader.load()

The load_document function returns a list of Document objects, which are later split into texts by a text splitter and then saved in the vector store.
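
For example, once the loader packages below are installed, loading a PDF might look like this (the file name is a placeholder):

docs = load_document("report.pdf")   # hypothetical file
print(len(docs))                     # PyPDFLoader returns one Document per page
print(docs[0].page_content[:200])    # the first 200 characters of page one
print(docs[0].metadata)              # e.g. {'source': 'report.pdf', 'page': 0}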

Install the following packages that are required by some of the document loaders:
pip install pypdf unstructured docx2txt markdown

Next we add the following code to create and load collections in our vector store:

# vector_functions.py

def create_collection(collection_name, documents):
    """
    Create a new Chroma collection from the given documents.

    Args:
    collection_name (str): The name of the collection to create.
    documents (list): A list of documents to add to the collection.

    Returns:
    None

    This function splits the documents into texts, creates a new Chroma collection,
    and persists it to disk.
    """

    # Split the documents into smaller text chunks
    texts = text_splitter.split_documents(documents)
    persist_directory = "./persist"

    # Create a new Chroma collection from the text chunks
    try:
        vectordb = Chroma.from_documents(
            documents=texts,
            embedding=embeddings,
            persist_directory=persist_directory,
            collection_name=collection_name,
        )
    except Exception as e:
        print(f"Error creating collection: {e}")
        return None

    return vectordb


def load_collection(collection_name):
    """
    Load an existing Chroma collection.

    Args:
    collection_name (str): The name of the collection to load.

    Returns:
    Chroma: The loaded Chroma collection.
    This function loads a previously created Chroma collection from disk.
    """
    persist_directory = "./persist"
    
    # Load the Chroma collection from the specified directory
    vectordb = Chroma(
        persist_directory=persist_directory,
        embedding_function=embeddings,
        collection_name=collection_name,
    )

    return vectordb



def add_documents_to_collection(vectordb, documents):
    """
    Add documents to the vector database collection.

    Args:
        vectordb: The vector database object to add documents to.
        documents: A list of documents to be added to the collection.
    This function splits the documents into smaller chunks, adds them to the
    vector database, and persists the changes.
    """

    # Split the documents into smaller text chunks
    texts = text_splitter.split_documents(documents)

    # Add the text chunks to the vector database
    vectordb.add_documents(texts)

    return vectordb

We have created three functions: create_collection, load_collection and add_documents_to_collection.

The create_collection function receives collection_name and documents as arguments. The documents argument is the list of Document objects loaded with the load_document function. We split the documents into chunks of up to 1000 characters each (this, of course, can be changed). We initialize a persist_directory, which is where the collections are going to be stored. Finally, we create a vector store from the given documents.

The load_collection function loads any collection that was created and returns it.

The add_documents_to_collection function lets us add new documents to an existing collection. While using the app, we start with an empty collection. When we upload a document, it is first loaded with the load_document function, which returns a list of Document objects. The system then loads the chat's collection using the load_collection function, and the documents are finally added to it using add_documents_to_collection.
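
Putting these three functions together, an end-to-end ingestion flow for a single chat might look like this (the file names are placeholders; chat_{id} is the collection naming convention we adopt later in the interface):

from vector_functions import (
    load_document,
    create_collection,
    load_collection,
    add_documents_to_collection,
)

# First upload for chat 1: create the collection from the document
docs = load_document("notes.txt")  # hypothetical file
vectordb = create_collection("chat_1", docs)

# A later upload for the same chat: load the collection and append to it
more_docs = load_document("appendix.pdf")  # hypothetical file
vectordb = load_collection("chat_1")
vectordb = add_documents_to_collection(vectordb, more_docs)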

Next, a retriever is required: a component that queries the vector store and returns the most relevant results for a given query:

# vector_functions.py

def load_retriever(collection_name, score_threshold: float = 0.6):
    """
    Create a retriever from a Chroma collection with a similarity score threshold.

    Args:
    collection_name (str): The name of the collection to use.
    score_threshold (float): The minimum similarity score threshold for retrieving documents.
                           Documents with scores below this threshold will be filtered out.
                           Defaults to 0.6.
    Returns:
    Retriever: A retriever object that can be used to query the collection with similarity
              score filtering.
    This function loads a Chroma collection and creates a retriever from it that will only
    return documents meeting the specified similarity score threshold.

    """

    # Load the Chroma collection
    vectordb = load_collection(collection_name)

    # Create a retriever from the collection with specified search parameters
    retriever = vectordb.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": score_threshold},
    )
    return retriever

The load_retriever function accepts two arguments: collection_name and score_threshold.

Some of the methods we can use to perform searches in the vector store are:

  • Similarity Search: Finds the items "closest" to your search query, where closeness is measured by metrics such as cosine similarity or Euclidean distance. You usually have to specify the number of items, k, that you want back.
  • Similarity Score Threshold: Returns only results with a similarity score above a given threshold. This is useful when you want the most relevant matches without unrelated results.
  • Maximal Marginal Relevance (MMR): Balances how relevant each result is to the query against how similar it is to the results already selected, so the most relevant items still rank highly while near-duplicate chunks are filtered out.

We are making use of the Similarity Score Threshold method; however, you can also try the other methods to see how well they perform, as in the sketch below.
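
If you want to experiment, the other search types can be configured through the same as_retriever call; here is a sketch based on the LangChain vector store API:

# Plain similarity search: always return the k closest chunks
retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": 4})

# Maximal Marginal Relevance: trade relevance off against diversity
retriever = vectordb.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 4, "fetch_k": 20},  # pick 4 diverse chunks from the 20 closest
)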

And now we add the last function before moving on to the interface:

# vector_functions.py

def generate_answer_from_context(retriever, question: str):
    """
    Ask a question and get an answer based on the provided context.

    Args:
        retriever: A retriever object to fetch relevant context.
        question (str): The question to be answered.

    Returns:
        str: The answer to the question based on the retrieved context.
    """

    # Define the message template for the prompt
    message = """
    Answer this question using the provided context only.
    {question}

    Context:
    {context}
    """

    # Create a chat prompt template from the message
    prompt = ChatPromptTemplate.from_messages([("human", message)])

    # Create a RAG (Retrieval-Augmented Generation) chain
    # This chain retrieves context, passes through the question,
    # formats the prompt, and generates an answer using the language model
    rag_chain = {"context": retriever, "question": RunnablePassthrough()} | prompt | llm

    # Invoke the RAG chain with the question and return the generated content
    return rag_chain.invoke(question).content

The purpose of this function is to return an answer that the LLM generates using the data retrieved by the retriever. It accepts two arguments:

  • retriever - to find relevant text
  • question - a string containing the user's query.

We define a message template that structures how we want the question and context to look before we send them to the model. {question} and {context} are placeholders that are going to be filled later on with actual values.

ChatPromptTemplate.from_messages([("human", message)]) creates a structured prompt from our template. This is useful as it guides the model's response, helping it generate meaningful output. You can read the documentation for more information.

Next, we create the RAG chain:

rag_chain = {"context": retriever, "question": RunnablePassthrough()} | prompt | llm

Let's break down what's happening:

  1. {"context": retriever, "question": RunnablePassthrough()}
    • This dictionary stores input components for the pipeline. We set the "context" to retriever, which is going to return the relevant text as context.
    • "question" uses RunnablePassthrough(), meaning it passes the question along the chain without modifying it. You can get more information about how it works from the documentation.
  2. | prompt
    • This part pipes the output from the previous dictionary into the prompt object which formats these inputs into a structured prompt.
  3. | llm
    • The final stage sends the prompt to the LLM which then generates the response.

rag_chain.invoke(question) runs the entire chain, passing in the question and returning the answer based on the given context.
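
As a standalone test, assuming you have already ingested some content into a chat_1 collection, the whole pipeline can be exercised in a few lines:

retriever = load_retriever("chat_1")
answer = generate_answer_from_context(retriever, "What is the document about?")
print(answer)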

All that's left now is to create the interface and connect it with the different functions...

User Interface

The application has 2 pages:

  • The chats home page
  • The chat page

Let's start by creating the chats home page. Create a file named chats.py and add the following code:


import streamlit as st
import os, time, math
import requests
from bs4 import BeautifulSoup
from langchain_core.documents import Document
from db import (
    read_chat,
    create_chat,
    list_chats,
    delete_chat,
    create_message,
    get_messages,
    create_source,
    list_sources,
    delete_source,
)

from vector_functions import (
    load_document,
    create_collection,
    load_retriever,
    generate_answer_from_context,
    add_documents_to_collection,
    load_collection,
)

def chats_home():
    st.markdown(
        "<h1 style='text-align: center;'>DocSage🧙‍♂️</h1>", unsafe_allow_html=True
    )

    with st.container(border=True):
        col1, col2 = st.columns([0.8, 0.2])

        with col1:
            chat_title = st.text_input(
                "Chat Title", placeholder="Enter Chat Title", key="chat_title"
            )

        with col2:
            st.markdown("<br>", unsafe_allow_html=True)  # Add vertical space
            if st.button("Create Chat", type="primary"):
                if chat_title:
                    chat_id = create_chat(chat_title)
                    st.success(f"Created new chat: {chat_title}")
                    st.query_params.from_dict({"chat_id": chat_id})
                    st.rerun()
                else:
                    st.warning("Please enter a chat title")

    with st.container(border=True):
        st.subheader("Previous Chats")

        # get previous chats from db
        previous_chats = list_chats()

        # Pagination settings
        chats_per_page = 5
        total_pages = math.ceil(len(previous_chats) / chats_per_page)

        # Get current page from session state
        if "current_page" not in st.session_state:
            st.session_state.current_page = 1

        # Calculate start and end indices for the current page
        start_idx = (st.session_state.current_page - 1) * chats_per_page
        end_idx = start_idx + chats_per_page

        # Display chats for the current page
        for chat in previous_chats[start_idx:end_idx]:
            chat_id, chat_title = chat[0], chat[1]
            with st.container(border=True):
                col1, col2, col3 = st.columns([0.6, 0.2, 0.2])

                with col1:
                    st.markdown(f"**{chat_title}**")
                with col2:
                    if st.button("📂 Open", key=f"open_{chat_id}"):
                        st.query_params.from_dict({"chat_id": chat_id})
                        st.rerun()

                with col3:
                    if st.button("🗑️ Delete", key=f"delete_{chat_id}"):
                        delete_chat(chat_id)
                        st.success(f"Deleted chat: {chat_title}")
                        st.rerun()


        # Pagination controls
        col1, col2, col3 = st.columns([1, 2, 1])

        with col1:
            if st.button("Previous") and st.session_state.current_page > 1:
                st.session_state.current_page -= 1
                st.rerun()
        with col2:
            st.write(f"Page {st.session_state.current_page} of {total_pages}")
        with col3:
            if st.button("Next") and st.session_state.current_page < total_pages:
                st.session_state.current_page += 1
                st.rerun()



def main():
    chats_home()

if __name__ == "__main__":
    main()

The chats home page consists of a text input that allows us to create chats, plus a paginated list of the chats we have created:


When a user adds a chat title and clicks the create chat button, the following code is executed:

if st.button("Create Chat", type="primary"):
    if chat_title:
        chat_id = create_chat(chat_title)
        st.success(f"Created new chat: {chat_title}")
        st.query_params.from_dict({"chat_id": chat_id})
        st.rerun()
    else:
        st.warning("Please enter a chat title")

If the text input is not empty, the create_chat function is called with chat_title as the argument. This function returns the id of the created chat. A success message is shown, and we set the query params to the chat_id using the from_dict method. This is going to allow us to navigate to the chat page for the newly created chat. If chat_title is empty, a warning is simply displayed.

Now to add the chat page, add the following code to chats.py:


def stream_response(response):
    """
    Stream a response word by word with a delay between each word.
    Args:
        response (str): The text response to stream
    Yields:
        str: Individual words from the response with a space appended
    Note:
        Adds a 50ms delay between each word to create a typing effect
    """

    # Split response into words and stream each one
    for word in response.split():
        # Yield the word with a space and pause briefly
        yield word + " "
        time.sleep(0.05)

The stream_response function simulates a ChatGPT-style streaming response, writing each word after a 50ms delay rather than spilling out all the text in one go.

Next, add the function for the chat page:


def chat_page(chat_id):
    """
    Display the chat page for a specific chat ID.

    Shows the conversation history for the chat and lets the user
    continue the conversation. If the chat does not exist, an error
    is displayed instead.
    """
    chat = read_chat(chat_id)

    if not chat:
        st.error("Chat not found")
        return

    # Retrieve messages from DB
    messages = get_messages(chat_id)

    # Display messages
    if messages:
        for sender, content in messages:
            if sender == "user":
                with st.chat_message("user"):
                    st.markdown(content)
            elif sender == "ai":
                with st.chat_message("assistant"):
                    st.markdown(content)
    else:
        st.write("No messages yet. Start the conversation!")

The function first retrieves the chat using the read_chat function. If the chat is not found, it displays an error message and returns. If the chat exists, it uses the get_messages function to retrieve the messages exchanged between the user and the LLM and adds them to the page. If no messages are found, a short message urging the user to start a conversation is displayed.

Now add the following code to the chat_page function:

def chat_page(chat_id):
    # Rest of the code ...

    # Add a text input for new messages
    prompt = st.chat_input("Type your message here...")
    if prompt:
        # Save user message
        create_message(chat_id, "user", prompt)

        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)

        # Get AI response
        # Load retriever for the chat context
        collection_name = f"chat_{chat_id}"
        if os.path.exists("./persist"):
            retriever = load_retriever(collection_name=collection_name)
        else:
            retriever = None

        # Ask question using the retriever
        response = (
            generate_answer_from_context(retriever, prompt)
            if retriever
            else "I need some context to answer that question."
        )

        # Save AI response
        create_message(chat_id, "ai", response)
    
        # Display AI response
        with st.chat_message("assistant"):
            st.write_stream(stream_response(response))

        st.rerun()

We use the streamlit chat_input to get the user's prompt. The prompt is then saved to the messages table using the create_message function and the message is displayed on the page using st.chat_message.

The next section gets the AI response as follows:

  • creates the collection name using the chat_id
  • checks if the persist directory exists which is where the collections are stored
  • If the persist directory exists, a retriever object is created using load_retriever method.
  • If the persist directory does not exist, the retriever object is set to None
  • If the retriever object is not None, the generate_answer_from_context method is used to get a response from the model. Remember the retriever first retrieves relevant texts and feeds them to the LLM. The LLM uses these texts to come up with an appropriate response.

The LLM's response is then saved to the database using the create_message function and written onto the page using st.write_stream and the stream_response function. Finally, the page is reloaded with the rerun method so the conversation displays correctly.

Sidebar

In this section, we look at creating the sidebar as well as saving documents and links to the vector store. Add the following code to the chat_page function:


def chat_page(chat_id):
    # rest of the code ...

    # Sidebar
    with st.sidebar:
        # Button to return to the main chats page
        if st.button("Back to Chats"):
            st.query_params.clear()
            st.rerun()
            
        # Chat name
        st.subheader(f"{chat[1]}")

        # Documents Section
        st.subheader("📑 Documents")

        # Get all "document" type sources
        documents = list_sources(chat_id, source_type="document")

        if documents:
            # list the documents
            for doc in documents:
                doc_id = doc[0]
                doc_name = doc[1]
                col1, col2 = st.columns([0.8, 0.2])
                with col1:
                    st.write(doc_name)
                with col2:
                    if st.button("", key=f"delete_doc_{doc_id}"):
                        delete_source(doc_id)
                        st.success(f"Deleted document: {doc_name}")
                        st.rerun()
        else:
            st.write("No documents uploaded.")

        uploaded_file = st.file_uploader("Upload Document", key="file_uploader")

        if uploaded_file:

            # Save document content to database
            with st.spinner("Processing document..."):
                temp_dir = "temp_files"
                os.makedirs(temp_dir, exist_ok=True)
                temp_file_path = os.path.join(temp_dir, uploaded_file.name)
                
                with open(temp_file_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())

                # Load document
                document = load_document(temp_file_path)
                
                # Create or update collection for this chat
                collection_name = f"chat_{chat_id}"

                if not os.path.exists("./persist"):  # no vector store on disk yet
                    vectordb = create_collection(collection_name, document)
                else:
                    vectordb = load_collection(collection_name)
                    vectordb = add_documents_to_collection(vectordb, document)

                # Save source to database
                create_source(uploaded_file.name, "", chat_id, source_type="document")

                # Remove temp file
                os.remove(temp_file_path)
                del st.session_state["file_uploader"]

                st.rerun()


The above code lists documents that were previously added to the vector store, if there are any, and allows the user to delete them.

When a user uploads a document, a spinner is shown while the document is added to the vector database. A temporary directory is created if it does not already exist, and the file is saved there. The document is then loaded using the load_document function, which returns the file's contents as a list of Document objects. A collection is created or loaded, depending on whether one had been created before; if it had, the documents are added to it using the add_documents_to_collection function. The create_source function saves the source's details in the relational database, and finally the file is deleted from the temporary directory.

Now let us complete the function by adding link processing. In the chat_page function, add the following code:


def chat_page(chat_id):
    # rest of the code

    with st.sidebar:
        # rest of the code...

        # Links Section
        st.subheader("🔗 Links")

        # Display list of links
        links = list_sources(chat_id, source_type="link")

        if links:
            for link in links:
                link_id = link[0]
                link_url = link[1]
                col1, col2 = st.columns([0.8, 0.2])
                with col1:
                    st.markdown(f"[{link_url}]({link_url})")

                with col2:
                    if st.button("❌    ", key=f"delete_link_{link_id}"):
                        delete_source(link_id)
                        st.success(f"Deleted link: {link_url}")
                        st.rerun()
        else:
            st.write("No links added.")

        # Add new link
        new_link = st.text_input("Add a link", key="new_link")
        if st.button("Add Link", key="add_link_btn"):
            if new_link:
                with st.spinner("Processing link..."):
                    # Fetch content from the link
                    try:
                        headers = {

                            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.111 Safari/537.36"
                        }
                        response = requests.get(new_link, headers=headers)
                        soup = BeautifulSoup(response.text, "html.parser")

                        # Check if the content was successfully retrieved
                        if response.status_code == 200 and soup.text.strip():
                            link_content = soup.get_text(separator="\n")
                        else:
                            st.toast(
                                "Unable to retrieve content from the link. It may be empty or inaccessible.",
                                icon="🚨",
                            )
                            return

                        # Save link content to vector store
                        documents = [
                            Document(
                                page_content=link_content, metadata={"source": new_link}
                            )
                        ]

                        collection_name = f"chat_{chat_id}"
                        
                        if not os.path.exists("./persist"):
                            create_collection(collection_name, documents)
                        else:
                            vectordb = load_collection(collection_name)
                            add_documents_to_collection(vectordb, documents)



                        # Save link to database
                        create_source(new_link, "", chat_id, source_type="link")
                        st.success(f"Added link: {new_link}")
                        del st.session_state["add_link_btn"]
                        st.rerun()
                    except Exception as e:
                        st.toast(

                            f"Failed to fetch content from the link: {e}", icon="⚠️"
                        )
            else:
                st.toast("Please enter a link", icon="")


The links are listed just like the documents. When a link is pasted into the input and the user clicks the Add Link button, a spinner indicates that the request is being processed. Using the requests library, the code attempts to fetch the content of the webpage. The User-Agent header mimics a browser request, since many websites will otherwise deny access to their content.

BeautifulSoup is then used to extract the text from the response. The code checks whether the request was successful (status code 200) and the content is not empty; if not, a toast message indicates that the link may be empty or inaccessible. If the content is valid, it is wrapped in a Document object (with metadata recording the source URL) and, just like before, added to the collection.
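
As a design note, LangChain also ships a WebBaseLoader that wraps this same requests-plus-BeautifulSoup pattern; a sketch of the equivalent loading step would be:

from langchain_community.document_loaders import WebBaseLoader

# Fetches the page and returns a list of Document objects,
# much like the manual requests + BeautifulSoup code above
documents = WebBaseLoader("https://example.com/article").load()

I used the manual approach here to keep full control over the headers and error handling.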

The full chat_page function looks as follows:

def chat_page(chat_id):

    """
    Display the chat page for a specific chat ID.

    This function handles displaying and managing an individual chat conversation, including:
    - Showing the chat history
    - Allowing users to send new messages
    - Streaming AI responses
    - Managing chat context through a vector store retriever
    Args:
        chat_id (int): The ID of the chat to display
    Returns:
        None
    """
    chat = read_chat(chat_id)
    if not chat:
        st.error("Chat not found")
        return

    # Retrieve messages from DB
    messages = get_messages(chat_id)

    # Display messages
    if messages:
        for sender, content in messages:
            if sender == "user":
                with st.chat_message("user"):
                    st.markdown(content)
            elif sender == "ai":
                with st.chat_message("assistant"):
                    st.markdown(content)
    else:
        st.write("No messages yet. Start the conversation!")

    # Add a text input for new messages
    prompt = st.chat_input("Type your message here...")
    if prompt:
        # Save user message
        create_message(chat_id, "user", prompt)

        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)

        # Get AI response
        # Load retriever for the chat context
        collection_name = f"chat_{chat_id}"
        if os.path.exists("./persist"):
            retriever = load_retriever(collection_name=collection_name)
        else:
            retriever = None

        # Ask question using the retriever
        response = (
            generate_answer_from_context(retriever, prompt)
            if retriever
            else "I need some context to answer that question."
        )

        # Save AI response
        create_message(chat_id, "ai", response)

        # Display AI response
        with st.chat_message("assistant"):
            st.write_stream(stream_response(response))

        st.rerun()

    # Sidebar for context
    with st.sidebar:
        # Button to return to the main chats page
        if st.button("Back to Chats"):
            st.query_params.clear()
            st.rerun()

        st.subheader(f"{chat[1]}")

        # Documents Section
        st.subheader("📑 Documents")

        # Display list of documents
        documents = list_sources(chat_id, source_type="document")
        if documents:
            for doc in documents:
                doc_id = doc[0]
                doc_name = doc[1]
                col1, col2 = st.columns([0.8, 0.2])
                with col1:
                    st.write(doc_name)
                with col2:
                    if st.button("", key=f"delete_doc_{doc_id}"):
                        delete_source(doc_id)
                        st.success(f"Deleted document: {doc_name}")
                        st.rerun()
        else:
            st.write("No documents uploaded.")

        uploaded_file = st.file_uploader("Upload Document", key="file_uploader")

        if uploaded_file:
            # Save document content to database
            with st.spinner("Processing document..."):
                temp_dir = "temp_files"
                os.makedirs(temp_dir, exist_ok=True)
                temp_file_path = os.path.join(temp_dir, uploaded_file.name)
                with open(temp_file_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())

                # Load document
                document = load_document(temp_file_path)

                # Create or update collection for this chat
                collection_name = f"chat_{chat_id}"
                if not os.path.exists("./persist"):  # no vector store on disk yet
                    vectordb = create_collection(collection_name, document)
                else:
                    vectordb = load_collection(collection_name)
                    vectordb = add_documents_to_collection(vectordb, document)

                # Save source to database
                create_source(uploaded_file.name, "", chat_id, source_type="document")

                # Remove temp file
                os.remove(temp_file_path)
                del st.session_state["file_uploader"]
                st.rerun()

        # Links Section
        st.subheader("🔗 Links")

        # Display list of links
        links = list_sources(chat_id, source_type="link")

        if links:
            for link in links:
                link_id = link[0]
                link_url = link[1]
                col1, col2 = st.columns([0.8, 0.2])
                with col1:
                    st.markdown(f"[{link_url}]({link_url})")
                with col2:
                    if st.button("❌    ", key=f"delete_link_{link_id}"):
                        delete_source(link_id)
                        st.success(f"Deleted link: {link_url}")
                        st.rerun()
        else:
            st.write("No links added.")

        # Add new link
        new_link = st.text_input("Add a link", key="new_link")
        if st.button("Add Link", key="add_link_btn"):
            if new_link:
                with st.spinner("Processing link..."):
                    # Fetch content from the link
                    try:
                        headers = {

                            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.111 Safari/537.36"
                        }
                        response = requests.get(new_link, headers=headers)
                        soup = BeautifulSoup(response.text, "html.parser")

                        # Check if the content was successfully retrieved
                        if response.status_code == 200 and soup.text.strip():
                            link_content = soup.get_text(separator="\n")
                        else:
                            st.toast(
                                "Unable to retrieve content from the link. It may be empty or inaccessible.",
                                icon="🚨",
                            )
                            return

                        # Save link content to vector store
                        documents = [
                            Document(
                                page_content=link_content, metadata={"source": new_link}
                            )
                        ]
                        collection_name = f"chat_{chat_id}"

                        if not os.path.exists("./persist"):
                            create_collection(collection_name, documents)
                        else:
                            vectordb = load_collection(collection_name)
                            add_documents_to_collection(vectordb, documents)

                        # Save link to database
                        create_source(new_link, "", chat_id, source_type="link")
                        st.success(f"Added link: {new_link}")
                        del st.session_state["add_link_btn"]
                        st.rerun()
                    except Exception as e:
                        st.toast(
                            f"Failed to fetch content from the link: {e}", icon="⚠️"
                        )
            else:
                st.toast("Please enter a link", icon="")


Now modify the main function as follows:

def main():
    """
    Main entry point for the chat application.
    Handles routing between the chats list page and individual chat pages:
    - If a chat_id is present in URL parameters, displays that specific chat
    - Otherwise shows the main chats listing page

    The function uses Streamlit query parameters to maintain state between page loads
    and determine which view to display.
    """

    query_params = st.query_params
    if "chat_id" in query_params:
        chat_id = query_params["chat_id"]
        chat_page(chat_id)
    else:
        chats_home()

if __name__ == "__main__":
    main()

The main function simply checks whether a chat_id is present in the URL parameters. If it is, it calls chat_page to display that specific chat; if not, it shows the main chats listing page.
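
One subtlety worth knowing: st.query_params returns values as strings, so chat_id arrives as, say, "3" rather than 3. SQLite's type affinity makes the lookups work anyway, but if you want to be strict you could cast it:

query_params = st.query_params
if "chat_id" in query_params:
    chat_id = int(query_params["chat_id"])  # cast the string parameter to an int
    chat_page(chat_id)
else:
    chats_home()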

With that, the application is complete!!!

Run the app with:
streamlit run chats.py

Conclusion

RAG opens up possibilities for building really intelligent, context-aware systems. Whether you're looking to build a knowledge-based chatbot, a document search tool, or any application requiring contextual responses, RAG is an ideal approach.

I hope you have found this article helpful!
