DEV Community

Tsubasa Kanno

Image search with Streamlit in Snowflake (SiS) Part 2 - Automatically generate image captions -

Introduction

This article is Part 2 of a series. If you haven't read Part 1 yet, please check it out first:

Part 1: Creating an Image Gallery using Streamlit in Snowflake

In Part 1, we created an image gallery app using Streamlit in Snowflake that displayed images stored in the app's default internal stage. In Part 2, we'll build upon that image gallery app to generate captions for each image, making it easier to utilize unstructured data.

Note: This article represents my personal views and not those of Snowflake.

Feature Overview

Goals

  • (Done) Display image data using Streamlit in Snowflake
  • *Add captions to images using Streamlit in Snowflake
  • Generate vector data based on image captions
  • Implement image search using Streamlit in Snowflake

*: Scope for Part 2

Features to be Implemented in Part 2

  • Manually create and edit image captions
  • Automatically generate captions for individual images
  • Bulk generate captions for images without captions
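The bulk feature relies on a simple selection rule: an image needs a caption if it has no row in `IMAGE_METADATA`, or if its `DESCRIPTION` is NULL or empty. Here is a framework-free sketch of that rule. The function name `find_images_without_captions` and the plain-dict input are hypothetical; the app applies the same idea with a pandas filter. Unlike the app's `!= ""` check, this version also treats whitespace-only captions as missing.

```python
from typing import Dict, List, Optional


def find_images_without_captions(
    image_files: List[str],
    captions: Dict[str, Optional[str]],
) -> List[str]:
    """Return the image file names that still need a caption.

    A caption counts as missing if the file has no metadata entry at all,
    if its DESCRIPTION is NULL (None), or if it is empty or whitespace-only.
    """
    missing = []
    for file_name in image_files:
        caption = captions.get(file_name)  # None if no metadata row exists
        if caption is None or not caption.strip():
            missing.append(file_name)
    return missing


if __name__ == "__main__":
    files = ["cat.jpg", "dog.png", "tree.gif", "car.jpeg"]
    metadata = {"cat.jpg": "A cat sleeping on a sofa.", "dog.png": "", "tree.gif": None}
    print(find_images_without_captions(files, metadata))
```

Running it prints `['dog.png', 'tree.gif', 'car.jpeg']`, because only `cat.jpg` already has a caption.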

Finished App for Part 2

[Screenshots of the completed Part 2 app]

Prerequisites

  • Snowflake
    • A Snowflake account
    • Packages added to the Streamlit in Snowflake app
      • boto3 1.28.64
  • AWS
    • An AWS account with access to Amazon Bedrock (we'll be using Claude 3.5 Sonnet in the us-east-1 region for this tutorial)

Steps

(Omitted) Create a Streamlit in Snowflake App and Upload Images

If you haven't done this yet, please follow the steps in Part 1.

(Omitted) Set Up Access to Amazon Bedrock from Streamlit in Snowflake

For instructions on setting up access to Amazon Bedrock, please refer to the article "Calling Amazon Bedrock directly from Streamlit in Snowflake (SiS)".

Run the Streamlit in Snowflake App

Copy and paste the following code into the Streamlit in Snowflake app editor:

import streamlit as st
import pandas as pd
import os
import base64
import boto3
import json
from snowflake.snowpark.context import get_active_session
from snowflake.snowpark.functions import col, when_matched, when_not_matched
import _snowflake
from PIL import Image
import io

# Streamlit page configuration
st.set_page_config(layout="wide", page_title="Image Gallery")

# Image folder path
IMAGE_FOLDER = "image"

# Get Snowflake session
session = get_active_session()

# Create table if it doesn't exist (only on first run)
@st.cache_resource
def create_table_if_not_exists():
    session.sql("""
    CREATE TABLE IF NOT EXISTS IMAGE_METADATA (
        FILE_NAME STRING,
        DESCRIPTION STRING,
        VECTOR VECTOR(FLOAT, 1024)
    )
    """).collect()

create_table_if_not_exists()

# Function to get AWS credentials
def get_aws_credentials():
    aws_key_object = _snowflake.get_username_password('bedrock_key')
    region = 'us-east-1'
    return {
        'aws_access_key_id': aws_key_object.username,
        'aws_secret_access_key': aws_key_object.password,
        'region_name': region
    }, region

# Set up Bedrock client
boto3_session_args, region = get_aws_credentials()
boto3_session = boto3.Session(**boto3_session_args)
bedrock = boto3_session.client('bedrock-runtime', region_name=region)

# Get image data
@st.cache_data
def get_image_data():
    image_files = [f for f in os.listdir(IMAGE_FOLDER) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif'))]
    return [{"FILE_NAME": f, "IMG_PATH": os.path.join(IMAGE_FOLDER, f)} for f in image_files]

# Get metadata
@st.cache_data
def get_metadata():
    return session.table("IMAGE_METADATA").select("FILE_NAME", "DESCRIPTION").to_pandas()

# Convert image to thumbnail and encode as base64
@st.cache_data
def get_thumbnail_base64(img_path, max_size=(300, 300)):
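    with Image.open(img_path) as img:
        # JPEG cannot store the alpha channel or palette that PNG/GIF files
        # often use, so convert to RGB before saving as JPEG
        thumb = img.convert("RGB")
        thumb.thumbnail(max_size)
        buffered = io.BytesIO()
        thumb.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')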

# Initialize image data and metadata
if 'img_df' not in st.session_state:
    st.session_state.img_df = get_image_data()
if 'metadata_df' not in st.session_state:
    st.session_state.metadata_df = get_metadata()

# Display image gallery
def show_image_gallery():
    st.title("Image Gallery")
    num_columns = st.slider("Width:", min_value=1, max_value=5, value=4)
    cols = st.columns(num_columns)
    for i, img in enumerate(st.session_state.img_df):
        with cols[i % num_columns]:
            st.image(img["IMG_PATH"], caption=None, use_column_width=True)

# Edit image descriptions
def edit_image_descriptions():
    st.title("Edit Image Descriptions")
    st.session_state.metadata_df = get_metadata()

    # Add new images to metadata
    for img in st.session_state.img_df:
        if img["FILE_NAME"] not in st.session_state.metadata_df["FILE_NAME"].values:
            new_row = pd.DataFrame({"FILE_NAME": [img["FILE_NAME"]], "DESCRIPTION": [""]})
            st.session_state.metadata_df = pd.concat([st.session_state.metadata_df, new_row], ignore_index=True)

    merged_df = pd.merge(st.session_state.metadata_df, pd.DataFrame(st.session_state.img_df), on="FILE_NAME", how="left")

    with st.form("edit_descriptions"):
        for _, row in merged_df.iterrows():
            col1, col2 = st.columns([1, 3])
            with col1:
                st.image(row["IMG_PATH"], width=100)
            with col2:
                new_description = st.text_input(f"Description for {row['FILE_NAME']}", value=row["DESCRIPTION"], key=row['FILE_NAME'])
                merged_df.loc[merged_df["FILE_NAME"] == row["FILE_NAME"], "DESCRIPTION"] = new_description

        submit_button = st.form_submit_button("Save Changes")

    if submit_button:
        update_snowflake_table(merged_df[['FILE_NAME', 'DESCRIPTION']])
        st.success("Changes saved successfully!")
        st.cache_data.clear()
        st.session_state.metadata_df = get_metadata()

# Function to generate image description
def generate_description(image_path):
    image_base64 = get_thumbnail_base64(image_path)
    prompt = """
    Describe this image in English within 400 characters, in a single line.
    Only output the image description without any additional response.
    """

    request_body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": 1024,  # upper bound on output tokens; a 400-character caption needs far fewer
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/jpeg",
                            "data": image_base64
                        }
                    },
                    {
                        "type": "text",
                        "text": prompt
                    }
                ]
            }
        ]
    }

    response = bedrock.invoke_model(
        body=json.dumps(request_body),
        modelId="anthropic.claude-3-5-sonnet-20240620-v1:0",
        accept='application/json',
        contentType='application/json'
    )

    response_body = json.loads(response.get('body').read())
    return response_body["content"][0]["text"]

# Function to update Snowflake table
def update_snowflake_table(update_df):
    snow_df = session.create_dataframe(update_df)

    session.table("IMAGE_METADATA").merge(
        snow_df,
        (session.table("IMAGE_METADATA").FILE_NAME == snow_df.FILE_NAME),
        [
            when_matched().update({
                "DESCRIPTION": snow_df.DESCRIPTION
            }),
            when_not_matched().insert({
                "FILE_NAME": snow_df.FILE_NAME,
                "DESCRIPTION": snow_df.DESCRIPTION
            })
        ]
    )

# Generate image descriptions
def generate_image_descriptions():
    st.title("Generate Image Descriptions")

    if 'generated_description' not in st.session_state:
        st.session_state.generated_description = None
    if 'selected_image' not in st.session_state:
        st.session_state.selected_image = None

    # Generate description for individual image
    with st.form("generate_description"):
        selected_image = st.selectbox("Select an image", options=[img["FILE_NAME"] for img in st.session_state.img_df])
        generate_button = st.form_submit_button("Generate Description")

    if generate_button:
        image_info = next(img for img in st.session_state.img_df if img['FILE_NAME'] == selected_image)
        generated_description = generate_description(image_info['IMG_PATH'])

        st.session_state.generated_description = generated_description
        st.session_state.selected_image = selected_image

        st.image(image_info['IMG_PATH'], width=300)
        st.write("Generated Description:")
        st.write(generated_description)

    if st.session_state.generated_description is not None:
        if st.button("Save Description"):
            update_snowflake_table(pd.DataFrame({'FILE_NAME': [st.session_state.selected_image], 'DESCRIPTION': [st.session_state.generated_description]}))
            st.success("Description saved successfully!")
            st.cache_data.clear()
            st.session_state.metadata_df = get_metadata()

            st.session_state.generated_description = None
            st.session_state.selected_image = None

    # Bulk process images without descriptions
    st.subheader("Bulk Process Images Without Descriptions")

    images_without_description = [
        img for img in st.session_state.img_df 
        if img["FILE_NAME"] not in st.session_state.metadata_df[
            st.session_state.metadata_df["DESCRIPTION"].notna() & 
            (st.session_state.metadata_df["DESCRIPTION"] != "")
        ]["FILE_NAME"].values
    ]

    if images_without_description:
        st.write(f"{len(images_without_description)} images do not have descriptions.")
        if st.button("Generate Descriptions in Bulk"):
            progress_bar = st.progress(0)
            for i, img in enumerate(images_without_description):
                generated_description = generate_description(img['IMG_PATH'])
                update_snowflake_table(pd.DataFrame({'FILE_NAME': [img['FILE_NAME']], 'DESCRIPTION': [generated_description]}))
                progress_bar.progress((i + 1) / len(images_without_description))

            st.success("Descriptions generated and saved for all images!")
            st.cache_data.clear()
            st.session_state.metadata_df = get_metadata()
    else:
        st.write("All images have descriptions.")

    # Display debug information
    st.subheader("Metadata Information")
    st.write(st.session_state.metadata_df)

# Main application execution
if __name__ == "__main__":
    page = st.sidebar.selectbox(
        "Choose a page", 
        ["Image Gallery", "Edit Descriptions", "Generate Descriptions"]
    )

    if page == "Image Gallery":
        show_image_gallery()
    elif page == "Edit Descriptions":
        edit_image_descriptions()
    elif page == "Generate Descriptions":
        generate_image_descriptions()

Code Explanation

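The following section creates a table to store image metadata. The third column, VECTOR, is reserved for the vector data we'll generate in Part 3. Because the function is decorated with @st.cache_resource, the CREATE TABLE statement runs once rather than on every rerun of the app: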

# Create table if it doesn't exist (only on first run)
@st.cache_resource
def create_table_if_not_exists():
    session.sql("""
    CREATE TABLE IF NOT EXISTS IMAGE_METADATA (
        FILE_NAME STRING,
        DESCRIPTION STRING,
        VECTOR VECTOR(FLOAT, 1024)
    )
    """).collect()

create_table_if_not_exists()
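Captions are written to this table by `update_snowflake_table`, which calls Snowpark's `Table.merge` with `when_matched().update(...)` and `when_not_matched().insert(...)`. Saving is therefore an upsert keyed on `FILE_NAME`: re-saving a caption overwrites the existing row instead of adding a duplicate. The plain-Python sketch below mimics those semantics, with a list of dicts standing in for `IMAGE_METADATA`. The helper name `merge_descriptions` is hypothetical, and no Snowflake connection is involved.

```python
from typing import Dict, Iterable, List, Optional, Tuple

# Hypothetical in-memory stand-in for the IMAGE_METADATA table: one dict per row.
Row = Dict[str, Optional[object]]


def merge_descriptions(table: List[Row], updates: Iterable[Tuple[str, str]]) -> Tuple[int, int]:
    """Mimic the MERGE issued by update_snowflake_table, keyed on FILE_NAME.

    - Matched rows: only DESCRIPTION is overwritten; VECTOR is left as-is.
    - Unmatched rows: a new row is appended with VECTOR set to None (NULL),
      because the INSERT clause does not supply a value for that column.
    Assumes FILE_NAME values in `updates` are unique, as they are in the app.
    Returns (rows_inserted, rows_updated).
    """
    index = {row["FILE_NAME"]: row for row in table}
    inserted = updated = 0
    for file_name, description in updates:
        row = index.get(file_name)
        if row is not None:
            row["DESCRIPTION"] = description  # WHEN MATCHED THEN UPDATE
            updated += 1
        else:
            new_row: Row = {"FILE_NAME": file_name, "DESCRIPTION": description, "VECTOR": None}
            table.append(new_row)  # WHEN NOT MATCHED THEN INSERT
            index[file_name] = new_row
            inserted += 1
    return inserted, updated


if __name__ == "__main__":
    table: List[Row] = [{"FILE_NAME": "cat.jpg", "DESCRIPTION": "", "VECTOR": None}]
    counts = merge_descriptions(table, [("cat.jpg", "A cat on a sofa."), ("dog.png", "A dog in a park.")])
    print(counts)
    print(table)
```

Because the MERGE only touches `DESCRIPTION`, an existing `VECTOR` value survives an update, while newly inserted rows start with `VECTOR` as NULL.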

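This function shrinks the image to a thumbnail and encodes it as Base64 before it is sent to Amazon Bedrock. Shrinking the image reduces the request size and the number of input tokens, which speeds up caption generation. Base64 is needed because the request body is JSON, and the Claude Messages API expects image data as a Base64-encoded string rather than raw binary: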

# Convert image to thumbnail and encode as base64
@st.cache_data
def get_thumbnail_base64(img_path, max_size=(300, 300)):
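    with Image.open(img_path) as img:
        # JPEG cannot store the alpha channel or palette that PNG/GIF files
        # often use, so convert to RGB before saving as JPEG
        thumb = img.convert("RGB")
        thumb.thumbnail(max_size)
        buffered = io.BytesIO()
        thumb.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')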
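To show how the Base64 string ends up in the request, here is a stdlib-only sketch that wraps image bytes in the same `image` content block that `generate_description` builds inline. The helper name `make_image_block` is hypothetical. The sample bytes only mimic a JPEG header and are not a real image.

```python
import base64
import json
from typing import Any, Dict


def make_image_block(jpeg_bytes: bytes, media_type: str = "image/jpeg") -> Dict[str, Any]:
    """Wrap raw image bytes in the Messages API 'image' content block shape."""
    return {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": media_type,
            # b64encode returns bytes; decode to str so the block is JSON-serializable
            "data": base64.b64encode(jpeg_bytes).decode("utf-8"),
        },
    }


if __name__ == "__main__":
    fake_jpeg = b"\xff\xd8\xff\xe0" + b"\x00" * 16  # JPEG magic bytes plus padding (20 bytes)
    block = make_image_block(fake_jpeg)
    payload = json.dumps(block)  # would raise TypeError if "data" were raw bytes
    print(len(fake_jpeg), len(block["source"]["data"]))
```

Base64 turns every 3 input bytes into 4 ASCII characters, so the 20-byte sample becomes a 28-character string. The encoded payload is always about a third larger than the JPEG itself, which is another reason to shrink images to at most 300 × 300 pixels before encoding.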

This section defines the prompt that is sent to Amazon Bedrock along with the image. We may need to adjust the prompt in Part 3 to improve search accuracy:

# Function to generate image description
def generate_description(image_path):
    image_base64 = get_thumbnail_base64(image_path)
    prompt = """
    Describe this image in English within 400 characters, in a single line.
    Only output the image description without any additional response.
    """
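The prompt asks for a single line of at most 400 characters, but a model may not follow such instructions exactly. The sketch below enforces the constraint in code as well. It extracts the text from a response body in the shape that `generate_description` already reads (`content[0]["text"]`), collapses any line breaks, and applies a hard 400-character cap. `extract_caption` and `MAX_CAPTION_CHARS` are hypothetical names, and the sample response is made up for illustration.

```python
import json
from typing import Union

MAX_CAPTION_CHARS = 400  # mirrors the limit stated in the prompt


def extract_caption(response_body: Union[str, bytes], max_chars: int = MAX_CAPTION_CHARS) -> str:
    """Return the first text block of a Messages API response as one trimmed line."""
    body = json.loads(response_body)  # json.loads accepts both str and bytes
    # Take the first block whose type is "text"; fall back to "" if there is none
    text = next((block["text"] for block in body.get("content", []) if block.get("type") == "text"), "")
    # Collapse newlines and runs of whitespace into single spaces
    single_line = " ".join(text.split())
    # Hard cap (may cut mid-word) in case the model overshoots the prompt's limit
    return single_line[:max_chars]


if __name__ == "__main__":
    sample = json.dumps({"content": [{"type": "text", "text": "A red bicycle\nleaning against   a brick wall."}]})
    print(extract_caption(sample))  # A red bicycle leaning against a brick wall.
```

Used in the app, the last line of `generate_description` could become `return extract_caption(response.get('body').read())`.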

Conclusion

With these improvements, we can now automatically generate captions for images in our gallery. While our ultimate goal is to implement image search, having captions associated with images opens up many possibilities for utilizing image data.

In Part 3, we'll implement image search by vectorizing the image captions and performing vector searches. Stay tuned!


Change Log

(20240923) Initial post

Original Japanese Article

https://zenn.dev/tsubasa_tech/articles/b6b9928d33ae58
