DEV Community

Tsubasa Kanno

Image search with Streamlit in Snowflake (SiS) Part 3 - Add vector search function for images -

Introduction

This article is Part 3. It's a continuation of Part 1 and Part 2, so if you haven't read them yet, please check out the following links first:

https://dev.to/tsubasa_tech/image-search-with-streamlit-in-snowflake-sis-part-1-creating-an-image-gallery-app--1bai

https://dev.to/tsubasa_tech/image-search-with-streamlit-in-snowflake-sis-part-2-automatically-generate-image-captions--2go5

In Part 1, we created an image gallery app with Streamlit in Snowflake that displays images stored in the app's default internal stage. In Part 2, we extended the gallery app with a feature that generates a caption for each image.

Finally, in Part 3, we'll complete the application by adding an image search feature! It uses vector search, so it returns relevant results even for vague keywords.

Note: This article is my personal publication. Please understand that it does not represent official statements from Snowflake.

Feature Overview

Goals

  • (Done) Display image data with Streamlit in Snowflake
  • (Done) Add descriptions to images with Streamlit in Snowflake
  • *Generate vector data based on image descriptions
  • *Perform image searches with Streamlit in Snowflake

*: Areas to be implemented in Part 3

Features to be Implemented in Part 3

  • Function to generate vector data from image captions
  • Function to search the image gallery by meaning (semantic search), even with vague keywords

Final Image for Part 3


Prerequisites

  • Snowflake
    • A Snowflake account
    • Packages to add to the Streamlit in Snowflake app
      • boto3 1.28.64
  • AWS
    • An AWS account with access to Amazon Bedrock (we'll be using the us-east-1 region in this guide to use Claude 3.5 Sonnet)

Basic Confirmation

Vectorization Options in Snowflake

In this article, we'll implement a search function by creating vector data from image captions. While vectorization might seem conceptually and technically challenging, Snowflake allows for easy implementation of vectorization and vector search. I hope this article will help you realize that "Vector search is easier than I thought and quite useful!"
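To build some intuition for what the search will compute, here is a minimal pure-Python sketch of cosine similarity, the measure that Snowflake's VECTOR_COSINE_SIMILARITY function returns. The three-dimensional vectors are made-up toy values for illustration only; the real embeddings produced by voyage-multilingual-2 in this app have 1,024 dimensions.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine of the angle between vectors a and b (range -1 to 1)."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_a * norm_b)


# Toy "embeddings" (hypothetical values, not real model output)
cat_photo = [0.9, 0.1, 0.0]
kitten_query = [0.8, 0.2, 0.1]
car_photo = [0.0, 0.2, 0.9]

print(round(cosine_similarity(kitten_query, cat_photo), 3))
print(round(cosine_similarity(kitten_query, car_photo), 3))
```

The query vector points in nearly the same direction as the cat caption's vector, so that score is close to 1, while the car caption scores much lower. Vector search is essentially this comparison repeated across every stored caption.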

For details on Snowflake's vectorization methods and performance, please refer to my separate article:

https://zenn.dev/tsubasa_tech/articles/c0a2b8793a5d1f

Steps

(Omitted) Create a Streamlit in Snowflake app and upload images

If you haven't done this yet, please follow the steps in Part 1's article first.

(Omitted) Enable access to Amazon Bedrock from the Streamlit in Snowflake app

As covered in Part 2, there are several options for adding captions to images automatically:

  1. Implement processing using Python image processing libraries
  2. Create an ML model for image recognition to generate captions for images
  3. Use existing AI models for images like BLIP-2 to generate captions
  4. Pass images to a multimodal GenAI to generate captions
  5. Use a SaaS service for caption generation

Since we've previously introduced how to connect to Amazon Bedrock, we'll use Amazon Bedrock's anthropic.claude-3-5-sonnet as a multimodal GenAI (option 4) to generate image captions.
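For reference, the multimodal request that the app sends to Claude 3.5 Sonnet on Bedrock follows the Anthropic Messages format: one user message whose content holds an image part (base64-encoded JPEG) and a text part, and the reply arrives as a list of content blocks. The sketch below only builds and parses these JSON payloads, so it runs without AWS access. The helper names build_caption_request and extract_caption are hypothetical, and the response is a hand-written stand-in, not real model output.

```python
import base64
import json


def build_caption_request(image_bytes: bytes, prompt: str, max_tokens: int = 1024) -> str:
    """Build the JSON body for a Claude Messages call on Amazon Bedrock (hypothetical helper)."""
    body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": base64.b64encode(image_bytes).decode("utf-8"),
                        },
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ],
    }
    return json.dumps(body)


def extract_caption(response_json: str) -> str:
    """Return the text of the first text content block in a Messages API response."""
    response = json.loads(response_json)
    return next(block["text"] for block in response["content"] if block["type"] == "text")


# Fake bytes and a fake response, for illustration only
request = build_caption_request(b"\xff\xd8fake-jpeg-bytes", "Describe this image.")
fake_response = json.dumps({"content": [{"type": "text", "text": "A cat on a sofa."}]})
print(extract_caption(fake_response))
```

In the app, a string like request is what gets passed as body= to bedrock.invoke_model(), and the parsing mirrors response_body["content"][0]["text"].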

For instructions on setting up access to Amazon Bedrock, please refer to the article "Calling Amazon Bedrock directly from Streamlit in Snowflake (SiS)".

Run the Streamlit in Snowflake App

Copy and paste the following code into the Streamlit in Snowflake app editor, then run the app:

import streamlit as st
import pandas as pd
import os
import base64
import boto3
import json
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import col, when_matched, when_not_matched, lit, call_udf
import _snowflake
from PIL import Image
import io

# Set custom theme
st.set_page_config(
    page_title="Image Gallery",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Add custom CSS
st.markdown("""
<style>
    .reportview-container {
        background: #f0f2f6;
    }
    .main .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
        padding-left: 5rem;
        padding-right: 5rem;
    }
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        padding: 10px 20px;
        border: none;
        border-radius: 5px;
        cursor: pointer;
        transition: background-color 0.3s;
    }
    .stButton>button:hover {
        background-color: #45a049;
    }
    .stTextInput>div>div>input {
        border-radius: 5px;
    }
    .stSelectbox>div>div>select {
        border-radius: 5px;
    }
    h1, h2, h3 {
        color: #2c3e50;
    }
    .stProgress > div > div > div > div {
        background-color: #4CAF50;
    }
</style>
""", unsafe_allow_html=True)

# Image folder path
IMAGE_FOLDER = "image"

# Get Snowflake session
session = get_active_session()

# Create table (only on first run)
@st.cache_resource
def create_table_if_not_exists():
    session.sql("""
    CREATE TABLE IF NOT EXISTS IMAGE_METADATA (
        FILE_NAME STRING,
        DESCRIPTION STRING,
        VECTOR VECTOR(FLOAT, 1024)
    )
    """).collect()

create_table_if_not_exists()

# Function to get AWS credentials
def get_aws_credentials():
    aws_key_object = _snowflake.get_username_password('bedrock_key')
    region = 'us-east-1'
    return {
        'aws_access_key_id': aws_key_object.username,
        'aws_secret_access_key': aws_key_object.password,
        'region_name': region
    }, region

# Set up Bedrock client
boto3_session_args, region = get_aws_credentials()
boto3_session = boto3.Session(**boto3_session_args)
bedrock = boto3_session.client('bedrock-runtime', region_name=region)

# Get image data
@st.cache_data
def get_image_data():
    image_files = [f for f in os.listdir(IMAGE_FOLDER) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif'))]
    return [{"FILE_NAME": f, "IMG_PATH": os.path.join(IMAGE_FOLDER, f)} for f in image_files]

# Get metadata
@st.cache_data
def get_metadata():
    return session.table("IMAGE_METADATA").select("FILE_NAME", "DESCRIPTION").to_pandas()

# Convert image to thumbnail and encode in base64
@st.cache_data
def get_thumbnail_base64(img_path, max_size=(300, 300)):
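    with Image.open(img_path) as img:
        img.thumbnail(max_size)
        # JPEG supports neither alpha channels nor palettes, so convert PNG/GIF images to RGB first
        rgb_img = img.convert("RGB")
        buffered = io.BytesIO()
        rgb_img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')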

# Initialize image data and metadata
if 'img_df' not in st.session_state:
    st.session_state.img_df = get_image_data()
if 'metadata_df' not in st.session_state:
    st.session_state.metadata_df = get_metadata()

# Display image gallery
def show_image_gallery():
    st.title("🖼️ Image Gallery")

    # Add search box
    search_query = st.text_input("Search images (top 10 most relevant results will be displayed)", "")

    if search_query:
        # Vectorize the search query and calculate similarity.
        # The query text is passed as a bind variable (?) instead of being
        # formatted into the SQL string, which prevents SQL injection.
        search_results = session.sql("""
        WITH search_vector AS (
            SELECT SNOWFLAKE.CORTEX.EMBED_TEXT_1024('voyage-multilingual-2', ?) as embedding
        )
        SELECT 
            i.FILE_NAME, 
            i.DESCRIPTION, 
            VECTOR_COSINE_SIMILARITY(i.VECTOR, s.embedding) as similarity
        FROM 
            IMAGE_METADATA i, 
            search_vector s
        WHERE 
            i.VECTOR IS NOT NULL
        ORDER BY 
            similarity DESC
        LIMIT 10
        """, params=[search_query]).collect()

        # Display search results
        st.subheader("Search Results")
        for result in search_results:
            file_name = result['FILE_NAME']
            description = result['DESCRIPTION']
            similarity = result['SIMILARITY']

            img_path = next((img['IMG_PATH'] for img in st.session_state.img_df if img['FILE_NAME'] == file_name), None)
            if img_path:
                col1, col2 = st.columns([1, 3])
                with col1:
                    st.image(img_path, width=150)
                with col2:
                    st.write(f"File name: {file_name}")
                    st.write(f"Description: {description}")
                    st.write(f"Match rate: {similarity:.1%}")
                st.markdown("---")
    else:
        # Normal gallery display
        num_columns = st.slider("Width:", min_value=1, max_value=5, value=4)
        cols = st.columns(num_columns)
        for i, img in enumerate(st.session_state.img_df):
            with cols[i % num_columns]:
                st.image(img["IMG_PATH"], caption=None, use_column_width=True)

# Edit image descriptions
def edit_image_descriptions():
    st.title("✏️ Edit Image Captions")
    st.session_state.metadata_df = get_metadata()

    # Add new images to metadata
    for img in st.session_state.img_df:
        if img["FILE_NAME"] not in st.session_state.metadata_df["FILE_NAME"].values:
            new_row = pd.DataFrame({"FILE_NAME": [img["FILE_NAME"]], "DESCRIPTION": [""]})
            st.session_state.metadata_df = pd.concat([st.session_state.metadata_df, new_row], ignore_index=True)

    merged_df = pd.merge(st.session_state.metadata_df, pd.DataFrame(st.session_state.img_df), on="FILE_NAME", how="left")

    with st.form("edit_descriptions"):
        for _, row in merged_df.iterrows():
            col1, col2 = st.columns([1, 3])
            with col1:
                st.image(row["IMG_PATH"], width=100)
            with col2:
                new_description = st.text_input(f"File name: {row['FILE_NAME']}", value=row["DESCRIPTION"], key=row['FILE_NAME'])
                merged_df.loc[merged_df["FILE_NAME"] == row["FILE_NAME"], "DESCRIPTION"] = new_description

        submit_button = st.form_submit_button("Save Changes")

    if submit_button:
        update_snowflake_table(merged_df[['FILE_NAME', 'DESCRIPTION']])
        st.success("Changes saved successfully!")
        st.cache_data.clear()
        st.session_state.metadata_df = get_metadata()

# Function to generate image description
def generate_description(image_path):
    image_base64 = get_thumbnail_base64(image_path)
    prompt = """
    Please describe this image in English within 400 characters in a single line.
    No need for a response, just output the image description.
    """

    request_body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1024,  # a ~400-character caption needs only a few hundred tokens
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": image_base64
                        }
                    },
                    {
                        "type": "text",
                        "text": prompt
                    }
                ]
            }
        ]
    }

    response = bedrock.invoke_model(
        body=json.dumps(request_body),
        modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",
        accept='application/json',
        contentType='application/json'
    )

    response_body = json.loads(response.get('body').read())
    return response_body["content"][0]["text"]

# Function to update Snowflake table
def update_snowflake_table(update_df):
    snow_df = session.create_dataframe(update_df)

    session.table("IMAGE_METADATA").merge(
        snow_df,
        (session.table("IMAGE_METADATA").FILE_NAME == snow_df.FILE_NAME),
        [
            when_matched().update({
                "DESCRIPTION": snow_df.DESCRIPTION
            }),
            when_not_matched().insert({
                "FILE_NAME": snow_df.FILE_NAME,
                "DESCRIPTION": snow_df.DESCRIPTION
            })
        ]
    )

# Generate image descriptions
def generate_image_descriptions():
    st.title("🤖 Automatic Image Caption Generation")

    if 'generated_description' not in st.session_state:
        st.session_state.generated_description = None
    if 'selected_image' not in st.session_state:
        st.session_state.selected_image = None

    # Generate description for individual image
    with st.form("generate_description"):
        selected_image = st.selectbox("Select an image:", options=[img["FILE_NAME"] for img in st.session_state.img_df])
        generate_button = st.form_submit_button("Generate Image Caption")

    if generate_button:
        image_info = next(img for img in st.session_state.img_df if img['FILE_NAME'] == selected_image)
        generated_description = generate_description(image_info['IMG_PATH'])

        st.session_state.generated_description = generated_description
        st.session_state.selected_image = selected_image

        st.image(image_info['IMG_PATH'], width=300)
        st.write("Generated Caption:")
        st.write(generated_description)

    if st.session_state.generated_description is not None:
        if st.button("Save Caption"):
            update_snowflake_table(pd.DataFrame({'FILE_NAME': [st.session_state.selected_image], 'DESCRIPTION': [st.session_state.generated_description]}))
            st.success("Caption saved successfully")
            st.cache_data.clear()
            st.session_state.metadata_df = get_metadata()

            st.session_state.generated_description = None
            st.session_state.selected_image = None

    # Batch process images without descriptions
    st.subheader("Batch Caption Generation for Uncaptioned Images")

    images_without_description = [
        img for img in st.session_state.img_df 
        if img["FILE_NAME"] not in st.session_state.metadata_df[
            st.session_state.metadata_df["DESCRIPTION"].notna() & 
            (st.session_state.metadata_df["DESCRIPTION"] != "")
        ]["FILE_NAME"].values
    ]

    if images_without_description:
        st.write(f"{len(images_without_description)} images don't have captions.")
        if st.button("Generate Captions in Batch"):
            progress_bar = st.progress(0)
            for i, img in enumerate(images_without_description):
                generated_description = generate_description(img['IMG_PATH'])
                update_snowflake_table(pd.DataFrame({'FILE_NAME': [img['FILE_NAME']], 'DESCRIPTION': [generated_description]}))
                progress_bar.progress((i + 1) / len(images_without_description))

            st.success("Captions generated and saved for all images!")
            st.cache_data.clear()
            st.session_state.metadata_df = get_metadata()
    else:
        st.write("All images have captions.")

    # Display debug information
    st.subheader("Metadata Information")
    st.write(st.session_state.metadata_df)

# Function to generate vector data using Cortex LLM's Embedding function
def generate_embedding(text):
    if text and text.strip():
        # Pass the text as a bind variable so quotes in captions can't break the SQL
        result = session.sql("SELECT SNOWFLAKE.CORTEX.EMBED_TEXT_1024('voyage-multilingual-2', ?) as embedding", params=[text]).collect()
        return result[0]['EMBEDDING']
    return None

# Function to generate and save vector data
def generate_and_save_vectors():
    st.title("🧬 Automatic Vector Data Generation")

    # Get metadata (including vector data information)
    full_metadata = session.table("IMAGE_METADATA").select("FILE_NAME", "DESCRIPTION", "VECTOR").to_pandas()

    # Extract images without vector data
    images_without_vector = full_metadata[
        (full_metadata['DESCRIPTION'].notna()) & 
        (full_metadata['DESCRIPTION'] != "") & 
        (full_metadata['VECTOR'].isna())  # Only rows without vector data
    ]

    if images_without_vector.empty:
        st.write("All images have vector data.")
    else:
        st.write(f"Vector data can be generated for {len(images_without_vector)} images.")
        if st.button("Generate Vector Data"):
            progress_bar = st.progress(0)
            for i, (_, row) in enumerate(images_without_vector.iterrows()):
                # Embed this row's caption with a single UPDATE; the file name is
                # passed as a bind variable (?) so that only this row is updated
                session.sql("""
                UPDATE IMAGE_METADATA
                SET VECTOR = SNOWFLAKE.CORTEX.EMBED_TEXT_1024('voyage-multilingual-2', DESCRIPTION)
                WHERE FILE_NAME = ? AND VECTOR IS NULL
                """, params=[row['FILE_NAME']]).collect()

                progress_bar.progress((i + 1) / len(images_without_vector))

            st.success("Vector data generated and saved for all target images!")
            st.cache_data.clear()

    # Display debug information
    st.subheader("Metadata Information")
    updated_full_metadata = session.table("IMAGE_METADATA").select("FILE_NAME", "DESCRIPTION", "VECTOR").to_pandas()
    st.write(updated_full_metadata)

# Main application execution
if __name__ == "__main__":
    st.sidebar.title("Navigation")
    page = st.sidebar.radio(
        "Select a feature to use:", 
        ["Image Gallery", "Edit Captions", "Auto-generate Captions", "Auto-generate Vector Data"]
    )

    if page == "Image Gallery":
        show_image_gallery()
    elif page == "Edit Captions":
        edit_image_descriptions()
    elif page == "Auto-generate Captions":
        generate_image_descriptions()
    elif page == "Auto-generate Vector Data":
        generate_and_save_vectors()

Explanation of Some Code Parts

The following section applies custom CSS to slightly customize the design. Streamlit in Snowflake allows customization of web applications using HTML, CSS, and JavaScript; here, st.markdown with unsafe_allow_html=True injects a <style> block into the page. For more details, please check the Streamlit in Snowflake documentation.

# Add custom CSS
st.markdown("""
<style>
    .reportview-container {
        background: #f0f2f6;
    }
    .main .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
        padding-left: 5rem;
        padding-right: 5rem;
    }
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        padding: 10px 20px;
        border: none;
        border-radius: 5px;
        cursor: pointer;
        transition: background-color 0.3s;
    }
    .stButton>button:hover {
        background-color: #45a049;
    }
    .stTextInput>div>div>input {
        border-radius: 5px;
    }
    .stSelectbox>div>div>select {
        border-radius: 5px;
    }
    h1, h2, h3 {
        color: #2c3e50;
    }
    .stProgress > div > div > div > div {
        background-color: #4CAF50;
    }
</style>
""", unsafe_allow_html=True)

The following part converts the user's search string into a vector and calculates its cosine similarity with each image's caption vector. The higher the similarity, the closer the caption's meaning is to the search string, so we retrieve the top 10 results by similarity. The search string is passed as a bind variable rather than formatted into the SQL text.

        # Vectorize the search query and calculate similarity
        search_results = session.sql("""
        WITH search_vector AS (
            SELECT SNOWFLAKE.CORTEX.EMBED_TEXT_1024('voyage-multilingual-2', ?) as embedding
        )
        SELECT 
            i.FILE_NAME, 
            i.DESCRIPTION, 
            VECTOR_COSINE_SIMILARITY(i.VECTOR, s.embedding) as similarity
        FROM 
            IMAGE_METADATA i, 
            search_vector s
        WHERE 
            i.VECTOR IS NOT NULL
        ORDER BY 
            similarity DESC
        LIMIT 10
        """, params=[search_query]).collect()
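As written, this query compares the query vector against every stored caption vector, skips rows whose VECTOR is NULL, sorts by similarity, and keeps the first 10. The following pure-Python sketch reproduces that logic on in-memory data to make each clause concrete. The file names and three-dimensional vectors are invented for illustration; Snowflake performs the same steps on 1,024-dimensional vectors.

```python
import heapq
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def top_k(query_vector, rows, k=10):
    """Return up to k (file_name, similarity) pairs, most similar first.

    rows is a list of (file_name, vector_or_None), mirroring IMAGE_METADATA.
    """
    scored = [
        (file_name, cosine_similarity(vector, query_vector))
        for file_name, vector in rows
        if vector is not None  # WHERE i.VECTOR IS NOT NULL
    ]
    # ORDER BY similarity DESC LIMIT k
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


rows = [
    ("cat.jpg", [0.9, 0.1, 0.0]),
    ("car.jpg", [0.0, 0.2, 0.9]),
    ("dog.jpg", [0.7, 0.3, 0.1]),
    ("new.png", None),  # captioned, but no vector generated yet
]
for name, score in top_k([0.8, 0.2, 0.1], rows, k=2):
    print(f"{name}: {score:.3f}")
```

Cosine similarity is a relative score rather than a probability, so the app's "Match rate" percentage is best read as a ranking signal, not as a confidence level.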

The following section generates vector data from image captions. Note how this is achieved with a simple 3-line SQL query.

        if st.button("Generate Vector Data"):
            progress_bar = st.progress(0)
            for i, (_, row) in enumerate(images_without_vector.iterrows()):
                # Embed this row's caption with a single UPDATE; the file name is
                # passed as a bind variable (?) so that only this row is updated
                session.sql("""
                UPDATE IMAGE_METADATA
                SET VECTOR = SNOWFLAKE.CORTEX.EMBED_TEXT_1024('voyage-multilingual-2', DESCRIPTION)
                WHERE FILE_NAME = ? AND VECTOR IS NULL
                """, params=[row['FILE_NAME']]).collect()

                progress_bar.progress((i + 1) / len(images_without_vector))

            st.success("Vector data generated and saved for all target images!")
            st.cache_data.clear()
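The loop above only processes rows that have a caption but no vector yet. The following pure-Python sketch isolates that bookkeeping so it can be checked without a Snowflake connection: it selects the pending rows and yields the same progress fraction the app passes to st.progress. The embed argument is a hypothetical stand-in for the Cortex EMBED_TEXT_1024 call, and the sketch treats whitespace-only captions as empty, matching the text.strip() check in the app's generate_embedding helper.

```python
from typing import Callable, Iterator


def pending_rows(rows: list[dict]) -> list[dict]:
    """Rows with a non-blank DESCRIPTION and no VECTOR yet."""
    return [
        row for row in rows
        if (row.get("DESCRIPTION") or "").strip() and row.get("VECTOR") is None
    ]


def generate_vectors(rows: list[dict], embed: Callable[[str], list[float]]) -> Iterator[float]:
    """Fill in VECTOR for pending rows in place, yielding progress (0 < p <= 1) after each row."""
    todo = pending_rows(rows)
    for i, row in enumerate(todo):
        row["VECTOR"] = embed(row["DESCRIPTION"])  # stand-in for EMBED_TEXT_1024
        yield (i + 1) / len(todo)


def fake_embed(text: str) -> list[float]:
    """Hypothetical toy embedding: not a real model."""
    return [float(len(text)), 0.0]


rows = [
    {"FILE_NAME": "a.jpg", "DESCRIPTION": "A red bicycle", "VECTOR": None},
    {"FILE_NAME": "b.jpg", "DESCRIPTION": "   ", "VECTOR": None},
    {"FILE_NAME": "c.jpg", "DESCRIPTION": "A sunset", "VECTOR": [0.1, 0.2]},
    {"FILE_NAME": "d.jpg", "DESCRIPTION": None, "VECTOR": None},
]
progress = list(generate_vectors(rows, fake_embed))
print(progress)
```

Only a.jpg is processed: b.jpg has a blank caption, c.jpg already has a vector, and d.jpg has no caption at all. Because rows that already have a vector are skipped, running the step again after adding new images only embeds the new captions.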

Conclusion

Building on the captions generated in the previous parts, this article added the ability to search images even with vague keywords. Because both the embedding and the similarity calculation run inside Snowflake, the same approach can be applied to much larger image collections without moving the data out of Snowflake.

Some ideas for further developing this app include:

  • Enabling similar image searches based on a specified image, not just user search strings
  • Extending search capabilities to other unstructured data types like documents and music
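As a starting point for the first idea, one possible approach is to reuse the stored vector of a chosen image as the query vector, rank the other images by cosine similarity, and exclude the chosen image itself. The sketch below does this in pure Python with invented data; similar_images is a hypothetical helper. Because the vectors come from captions, it finds images whose descriptions are similar, not images whose pixels are similar.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def similar_images(selected, vectors, k=3):
    """Rank other images by similarity to the selected image's caption vector.

    vectors maps FILE_NAME -> caption embedding (a hypothetical in-memory copy of IMAGE_METADATA).
    """
    query = vectors[selected]
    candidates = [
        (name, cosine_similarity(query, vec))
        for name, vec in vectors.items()
        if name != selected and vec is not None  # skip the image itself and unvectorized rows
    ]
    candidates.sort(key=lambda pair: pair[1], reverse=True)
    return candidates[:k]


vectors = {
    "cat.jpg": [0.9, 0.1, 0.0],
    "kitten.jpg": [0.8, 0.2, 0.1],
    "car.jpg": [0.0, 0.2, 0.9],
}
print(similar_images("cat.jpg", vectors, k=2))
```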

I'm sure you have many ideas of your own! I'd be delighted if you could use this article as a reference to bring your ideas to life.

Promotion

Sharing Snowflake's What's New on X

I'm sharing updates on Snowflake's What's New on X. I'd be happy if you could follow:

English Version

Snowflake What's New Bot (English Version)

Japanese Version

Snowflake's What's New Bot (Japanese Version)

Change History

(20240924) Initial post

Original Japanese Article

https://zenn.dev/tsubasa_tech/articles/64722ba45947a9
