DEV Community

Tsubasa Kanno


Building a Flow Diagram Auto-Generation App with Streamlit in Snowflake (SiS) and Cortex AI

Introduction

In this article, I introduce an application I built that combines Streamlit in Snowflake with Cortex AI to automatically generate flow diagrams and architecture diagrams. There are still areas I'd like to improve, so I may write a Part 2 in the future!

📝 Note: This article represents my personal views and not those of Snowflake.

What is Streamlit in Snowflake (SiS)?

First, Streamlit is a Python library that lets you build web UIs with simple Python code. Web UIs normally require HTML, CSS, and JavaScript, but Streamlit takes care of that layer for you.

You can get a better idea of what it can do by browsing the App Gallery, where users post sample apps.

Streamlit in Snowflake lets you develop and run Streamlit web apps on Snowflake. Needing nothing more than a Snowflake account is convenient, but I think the best part is how easily you can bring Snowflake table data into a web application.

About Streamlit in Snowflake (Snowflake Official Documentation)

What is Snowflake Cortex?

Snowflake Cortex is a suite of generative AI functions in Snowflake. Within this suite, Cortex LLM is a feature that allows you to call large language models running on Snowflake through simple functions from SQL or Python.

Large Language Model (LLM) Functions (Snowflake Cortex) (Snowflake Official Documentation)
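In this app, the AI generation feature sends a natural-language description to Cortex and renders the DOT code that comes back. Below is a minimal sketch of that round trip, not the app's actual implementation (that part of the code is outside this section). It assumes the model may wrap its answer in a Markdown fence or extra prose, so it pulls out the first `digraph`/`graph` block. The prompt wording and the helper names `build_dot_prompt`, `extract_dot`, and `generate_dot` are hypothetical, and `complete_fn` stands in for `snowflake.cortex.Complete` so the sketch runs without a Snowflake session.

```python
import re
from typing import Callable, Optional

# Hypothetical prompt template; the app's real prompt may differ.
PROMPT_TEMPLATE = (
    "Convert the following description into a Graphviz diagram. "
    "Return only DOT code that starts with 'digraph'.\n\n"
    "Description:\n{description}"
)

# A DOT header (optional "strict", then "graph" or "digraph", optional ID) on one
# line, followed by everything up to the last closing brace in the reply.
DOT_BLOCK = re.compile(r"\b(?:strict\s+)?(?:di)?graph\b[^{}\n]*\{.*\}", re.DOTALL)


def build_dot_prompt(description: str) -> str:
    """Wrap a natural-language description in an instruction for the LLM."""
    return PROMPT_TEMPLATE.format(description=description.strip())


def extract_dot(reply: str) -> Optional[str]:
    """Return the DOT block from an LLM reply, or None if there is none."""
    match = DOT_BLOCK.search(reply)
    return match.group(0).strip() if match else None


def generate_dot(
    complete_fn: Callable[[str, str], str], model: str, description: str
) -> Optional[str]:
    """Call complete_fn(model, prompt), e.g. snowflake.cortex.Complete, then extract DOT."""
    reply = complete_fn(model, build_dot_prompt(description))
    return extract_dot(reply)
```

Inside Streamlit in Snowflake you would pass the real function, for example `generate_dot(CompleteText, lang_model, st.session_state['ai_prompt'])`, and hand a non-None result to `st.graphviz_chart()`, which accepts a DOT string. Injecting `complete_fn` keeps the parsing logic testable without a live session.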

Feature Overview

Feature List

In the application I created, I implemented the following features:

  • A graphical editor for creating flow diagrams and architecture diagrams
    • Interactive features for adding and editing nodes and edges
    • Customizable shapes and style settings
    • Graph generation from configured items
  • Template functionality
    • Graph generation from several templates
    • Feature to reflect template content in the graphical editor
  • Direct DOT code editing functionality
    • Direct input of Graphviz syntax in DOT code format
    • Graph generation from inputted DOT code
  • AI generation functionality
    • Automatic flow diagram generation from natural language descriptions
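The graphical editor keeps nodes and edges as dictionaries in `st.session_state` and turns them into a graph; the generation step itself is outside this section. As a hedged sketch, assuming dictionaries with the same keys the app uses (`id`, `label`, `shape`, `style`, `fillcolor`, `color` for nodes; `source`, `target`, `label`, `style`, `color`, `dir`, `arrow` for edges), a plain-Python converter to DOT text could look like this. `to_dot`, `_q`, and `_attr_list` are hypothetical names, and only a subset of the attributes is handled.

```python
def _q(value) -> str:
    """Quote a value for DOT; inside quoted strings only double quotes need escaping."""
    return '"' + str(value).replace('"', '\\"') + '"'


def _attr_list(attrs: dict) -> str:
    """Render {"shape": "box"} as 'shape="box"'."""
    return ", ".join(f"{key}={_q(val)}" for key, val in attrs.items())


def to_dot(nodes: list, edges: list, rankdir: str = "LR") -> str:
    """Convert editor-style node/edge dicts into a DOT digraph (hypothetical helper)."""
    lines = ["digraph G {", f"  rankdir={rankdir};"]
    for node in nodes:
        attrs = {
            "label": node["label"],
            "shape": node.get("shape", "box"),
            "style": node.get("style", "filled"),
            "fillcolor": node.get("fillcolor", "#D0E8FF"),
            "color": node.get("color", "#000000"),
        }
        lines.append(f"  {_q(node['id'])} [{_attr_list(attrs)}];")
    for edge in edges:
        direction = edge.get("dir", "forward")
        arrow = edge.get("arrow", "normal")
        attrs = {
            # The editor stores e.g. "dashed (dashed line)"; DOT needs only "dashed".
            "style": edge.get("style", "solid").split(" ")[0],
            "color": edge.get("color", "#666666"),
            "dir": direction,
        }
        # With dir=back the glyph is drawn at the tail, so it comes from arrowtail.
        if direction in ("forward", "both"):
            attrs["arrowhead"] = arrow
        if direction in ("back", "both"):
            attrs["arrowtail"] = arrow
        if edge.get("label"):
            attrs["label"] = edge["label"]
        lines.append(f"  {_q(edge['source'])} -> {_q(edge['target'])} [{_attr_list(attrs)}];")
    lines.append("}")
    return "\n".join(lines)
```

For example, `to_dot(st.session_state['nodes'], st.session_state['edges'], rankdir="TB")` would produce text you could pass to `st.graphviz_chart()`. Quoting every ID and value is slightly verbose but avoids breaking on labels with spaces or quotes.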

Completed Images

Overall application view

Graphical editor for nodes and edges

Node and edge list display

Graph and DOT code generated from the graphical editor

Template functionality 1/2

Template functionality 2/2

Direct DOT code input functionality

Automatic graph generation functionality with Cortex AI

Graph generated by Cortex AI

Prerequisites

  • Snowflake account
    • A Snowflake account that can use Cortex LLM (now that cross-region inference is available, there are almost no cloud or region constraints)
  • Packages to add to the Streamlit in Snowflake app
    • Python 3.11 or later
    • snowflake-ml-python 1.7.4 or later
    • python-graphviz 0.20.1 or later

*Cortex LLM Regional Availability Table (Snowflake Official Documentation)

What is Graphviz?

Graphviz is an open-source library for visualizing graph structures. It's suitable for drawing directed and undirected graphs consisting of nodes and edges, such as network diagrams, flowcharts, organizational charts, ER diagrams, and various other illustrations.

Graphviz has the following features:

  • Uses a specialized description language called DOT code
  • Automatic layout functionality (automatically optimizes node placement)
  • Supports a variety of node shapes and edge styles
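Because DOT is plain text, IDs that contain spaces, start with a digit, or collide with a keyword must be double-quoted. Per the DOT language reference, a bare ID may be an identifier made of letters, digits, and underscores that does not start with a digit, or a numeral, and the keywords `node`, `edge`, `graph`, `digraph`, `subgraph`, and `strict` (case-insensitive) cannot be used unquoted. The sketch below applies those rules conservatively (it quotes anything non-ASCII and ignores HTML strings); `dot_id` is a hypothetical helper, not part of Graphviz.

```python
import re

# ASCII letter or underscore, followed by letters, digits, or underscores.
_IDENT = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
# DOT numerals: [-]?(.[0-9]+ | [0-9]+(.[0-9]*)?)
_NUMERAL = re.compile(r"-?(?:\.[0-9]+|[0-9]+(?:\.[0-9]*)?)")
# DOT keywords are case-independent and cannot be used as bare IDs.
_KEYWORDS = {"node", "edge", "graph", "digraph", "subgraph", "strict"}


def dot_id(name: str) -> str:
    """Return name unchanged if it is a valid bare DOT ID, else a quoted string."""
    is_bare = bool(_IDENT.fullmatch(name) or _NUMERAL.fullmatch(name))
    if is_bare and name.lower() not in _KEYWORDS:
        return name
    # Inside a quoted DOT string, only the double quote needs escaping.
    return '"' + name.replace('"', '\\"') + '"'
```

For instance, `f"{dot_id('Load Balancer')} -> {dot_id('web1')};"` yields `"Load Balancer" -> web1;`. When you build graphs through python-graphviz's `Digraph.node()` and `Digraph.edge()`, the library handles this quoting for you, so a helper like this matters mainly when assembling DOT strings by hand.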

For this project, I use the python-graphviz library, which is easy to install in Streamlit in Snowflake, to build the flow diagram creation app.

Procedure

Create a new Streamlit in Snowflake application

Click on "Streamlit" in the left pane of Snowsight, then click the "+ Streamlit" button to create a SiS application.

Run the Streamlit in Snowflake application

In the Streamlit in Snowflake app editor, copy and paste the following code to complete the setup.

import streamlit as st
import graphviz
import pandas as pd
from snowflake.snowpark.context import get_active_session
from snowflake.cortex import Complete as CompleteText

# Application settings
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded"
)

# Get Snowflake session
session = get_active_session()

# Application header
st.title("Streamlit in Snowflake Flow Diagram Creator")
st.markdown("A tool for creating architecture and data flow diagrams using Cortex AI and Graphviz.")

# Common function: Display nodes table
def display_nodes_table(nodes_list, is_template=False):
    if not nodes_list:
        st.info("No nodes have been added yet.")
        return

    nodes_data = {
        "Node ID": [],
        "Label": [],
        "Shape": [],
        "Style": [],
        "Fill Color": [],
        "Border Color": [],
        "Details": []
    }

    for node in nodes_list:
        # Convert template structure if needed
        if is_template:
            node_id = node["id"]
            node_label = node["label"]
            node_shape = node["attrs"].get("shape", "box")
            node_style = "filled"
            node_fillcolor = node["attrs"].get("fillcolor", "#D0E8FF")
            node_color = "#000000"

            nodes_data["Node ID"].append(node_id)
            nodes_data["Label"].append(node_label)
            nodes_data["Shape"].append(node_shape)
            nodes_data["Style"].append(node_style)
            nodes_data["Fill Color"].append(f"<span style='color:{node_fillcolor};'>■</span> {node_fillcolor}")
            nodes_data["Border Color"].append(f"<span style='color:{node_color};'>■</span> {node_color}")
            nodes_data["Details"].append("-")
        else:
            nodes_data["Node ID"].append(node['id'])
            nodes_data["Label"].append(node['label'])
            nodes_data["Shape"].append(node['shape'])
            nodes_data["Style"].append(node['style'])
            nodes_data["Fill Color"].append(f"<span style='color:{node['fillcolor']};'>■</span> {node['fillcolor']}")
            nodes_data["Border Color"].append(f"<span style='color:{node['color']};'>■</span> {node['color']}")

            # Summarize detailed settings
            details = []
            if 'peripheries' in node and node['peripheries'] > 1:
                details.append(f"Peripheries: {node['peripheries']}")
            if 'fontname' in node and node['fontname'] != 'sans-serif':
                details.append(f"Font: {node['fontname']}")
            if 'fontsize' in node and node['fontsize'] != 14:
                details.append(f"Size: {node['fontsize']}")
            if 'tooltip' in node and node['tooltip']:
                details.append("Has tooltip")

            nodes_data["Details"].append(", ".join(details) if details else "-")

    nodes_df = pd.DataFrame(nodes_data)
    st.write(nodes_df.to_html(escape=False), unsafe_allow_html=True)

# Common function: Display edges table
def display_edges_table(edges_list, is_template=False):
    if not edges_list:
        st.info("No edges have been added yet.")
        return

    edges_data = {
        "Source": [],
        "Target": [],
        "Label": [],
        "Style": [],
        "Color": [],
        "Direction": [],
        "Arrow": [],
        "Details": []
    }

    for edge in edges_list:
        # Convert template structure if needed
        if is_template:
            source = edge["source"]
            target = edge["target"]
            label = edge.get("label", "-")

            # Get style information
            style = "solid"
            if "style" in edge["attrs"]:
                style = edge["attrs"]["style"]
            elif "penwidth" in edge["attrs"] and edge["attrs"]["penwidth"] == "2":
                style = "bold"

            color = edge["attrs"].get("color", "#666666")
            arrow = edge["attrs"].get("arrowhead", "normal")

            edges_data["Source"].append(source)
            edges_data["Target"].append(target)
            edges_data["Label"].append(label)
            edges_data["Style"].append(style)
            edges_data["Color"].append(f"<span style='color:{color};'>■</span> {color}")
            edges_data["Direction"].append("forward")
            edges_data["Arrow"].append(arrow)
            edges_data["Details"].append("-")
        else:
            edges_data["Source"].append(edge['source'])
            edges_data["Target"].append(edge['target'])
            edges_data["Label"].append(edge['label'] if edge['label'] else "-")

            style_text = edge['style'].split(" ")[0] if " " in edge['style'] else edge['style']
            edges_data["Style"].append(style_text)

            edges_data["Color"].append(f"<span style='color:{edge['color']};'>■</span> {edge['color']}")
            edges_data["Direction"].append(edge['dir'])
            edges_data["Arrow"].append(edge['arrow'])

            details = []
            if 'fontname' in edge and edge['fontname'] != 'sans-serif':
                details.append(f"Font: {edge['fontname']}")
            if 'fontsize' in edge and edge['fontsize'] != 10:
                details.append(f"Size: {edge['fontsize']}")
            if 'penwidth' in edge and edge['penwidth'] != 1.0:
                details.append(f"Width: {edge['penwidth']}")
            if 'weight' in edge and edge['weight'] != 1.0:
                details.append(f"Weight: {edge['weight']}")
            if 'minlen' in edge and edge['minlen'] != 1:
                details.append(f"Min Length: {edge['minlen']}")
            if 'tooltip' in edge and edge['tooltip']:
                details.append("Has tooltip")

            edges_data["Details"].append(", ".join(details) if details else "-")

    edges_df = pd.DataFrame(edges_data)
    st.write(edges_df.to_html(escape=False), unsafe_allow_html=True)

# Initialize session state
if 'active_tab' not in st.session_state:
    st.session_state['active_tab'] = "Nodes and Edges Input"
if 'nodes' not in st.session_state:
    st.session_state['nodes'] = []
if 'edges' not in st.session_state:
    st.session_state['edges'] = []
if 'template_applied' not in st.session_state:
    st.session_state['template_applied'] = False
if 'template_type' not in st.session_state:
    st.session_state['template_type'] = ""
if 'template_previewed' not in st.session_state:
    st.session_state['template_previewed'] = False
if 'current_template_data' not in st.session_state:
    st.session_state['current_template_data'] = None
if 'dot_code' not in st.session_state:
    st.session_state['dot_code'] = """digraph G {
    rankdir=LR;
    node [shape=box];
    A [label="Start"];
    B [label="Process 1"];
    C [label="Process 2"];
    D [label="End"];
    A -> B [label="Flow 1"];
    B -> C [label="Flow 2"];
    C -> D [label="Flow 3"];
}"""
if 'ai_prompt' not in st.session_state:
    st.session_state['ai_prompt'] = """Create a web application system architecture diagram.
Users access through browsers, connecting to web servers via a load balancer.
Web servers communicate with application servers, which store data in database servers.
The system also uses cache servers."""

# Callback function for template or style changes
def reset_template_preview():
    """Reset preview state when template or style is changed"""
    st.session_state['template_previewed'] = False

# Callback function for tab changes
def on_tab_change():
    """Handle tab change events"""
    # Reset template preview state
    st.session_state['template_previewed'] = False

# Function to manage template application state
def on_template_apply_click(template_nodes, template_edges, template_node_shape, template_node_color, template_edge_style, template_edge_color, template_type):
    """Handle template application button click"""
    # Save pending template information to session
    st.session_state['current_template_data'] = {
        'nodes': template_nodes,
        'edges': template_edges,
        'node_shape': template_node_shape,
        'node_color': template_node_color,
        'edge_style': template_edge_style,
        'edge_color': template_edge_color,
        'type': template_type
    }
    # Set template application flag
    st.session_state['template_applied'] = True
    # Switch active tab
    st.session_state['active_tab'] = "Nodes and Edges Input" 

# Function to apply template to nodes and edges input
def apply_template_to_session(template_nodes, template_edges, template_node_shape, template_node_color, template_edge_style, template_edge_color, template_type):
    """Apply template to nodes and edges input (direct session update)"""
    # Clear existing nodes and edges
    st.session_state['nodes'] = []
    st.session_state['edges'] = []

    # Add template nodes
    for node in template_nodes:
        st.session_state['nodes'].append({
            'id': node["id"],
            'label': node["label"],
            'shape': template_node_shape,
            'style': "filled",
            'fillcolor': template_node_color,
            'color': "#000000",
            'peripheries': 1,
            'fontname': "sans-serif",
            'fontsize': 14,
            'fontcolor': "#000000",
            'tooltip': ""
        })

    # Add template edges
    for edge in template_edges:
        edge_style = "solid (solid line)"
        if template_edge_style == "dashed":
            edge_style = "dashed (dashed line)"
        elif template_edge_style == "thick":
            edge_style = "bold (thick line)"
        elif template_edge_style == "arrow":
            edge_style = "solid (solid line)"

        # Get edge attributes
        arrow_head = "normal"
        if "arrowhead" in edge["attrs"]:
            arrow_head = edge["attrs"]["arrowhead"]

        st.session_state['edges'].append({
            'source': edge["source"],
            'target': edge["target"],
            'label': edge.get("label", ""),
            'style': edge_style,
            'color': template_edge_color,
            'dir': "forward",
            'arrow': arrow_head,
            'fontname': "sans-serif",
            'fontsize': 10,
            'fontcolor': "#000000",
            'tooltip': "",
            'penwidth': 2.0 if template_edge_style == "thick" else 1.0,
            'weight': 1.0,
            'constraint': True,
            'minlen': 1
        })

    # Set template information
    st.session_state['template_applied'] = True
    st.session_state['template_type'] = template_type
    # Switch tab
    st.session_state['active_tab'] = "Nodes and Edges Input"

# Sidebar: Basic settings
st.sidebar.title("Basic Settings")
graph_direction = st.sidebar.radio(
    "Graph Direction",
    ["Left to Right (LR)", "Top to Bottom (TB)"]
)

# Sidebar: AI settings
st.sidebar.title("Cortex AI Settings")
use_llm = st.sidebar.checkbox("Enable AI Generation", value=False)

if use_llm:
    lang_model = st.sidebar.radio("Select Cortex AI Model",
                                  ("deepseek-r1",
                                   "claude-3-5-sonnet",
                                   "mistral-large2", "mixtral-8x7b", "mistral-7b",
                                   "llama3.3-70b",
                                   "llama3.2-1b", "llama3.2-3b",
                                   "llama3.1-8b", "llama3.1-70b", "llama3.1-405b",
                                   "snowflake-llama-3.1-405b", "snowflake-llama-3.3-70b",
                                   "snowflake-arctic",
                                   "reka-flash", "reka-core",
                                   "jamba-instruct", "jamba-1.5-mini", "jamba-1.5-large",
                                   "gemma-7b",
                                   "mistral-large", "llama3-8b", "llama3-70b", "llama2-70b-chat"
                                  ),
                                  index=1)

# Main content area
st.header("Graph Creation")

# Input method tabs
tab_options = ["Nodes and Edges Input", "Templates", "Direct DOT Code Input", "AI Generation"]
selected_tab = st.radio("Select Input Method", tab_options, horizontal=True, key="tab_selector", on_change=on_tab_change)
st.session_state['active_tab'] = selected_tab 

# Tab 1: Nodes and Edges Input
if selected_tab == "Nodes and Edges Input":
    # Check if template was applied
    if st.session_state.get('template_applied'):
        st.success(f"Applied template '{st.session_state.get('template_type')}'. Nodes: {len(st.session_state['nodes'])}, Edges: {len(st.session_state['edges'])}")
        # Reset flag
        st.session_state['template_applied'] = False

    node_col, edge_col = st.columns(2)

    # Node input form
    with node_col:
        st.subheader("Node Definition")

        with st.form(key="node_form"):
            node_id = st.text_input("Node ID", placeholder="Example: node1, server1")
            node_label = st.text_input("Node Label", placeholder="Example: Server 1, Database")

            node_shape = st.selectbox(
                "Node Shape",
                ["box", "ellipse", "circle", "diamond", "plaintext", "polygon", "triangle", "hexagon", "cylinder", 
                 "folder", "component", "note", "tab", "house", "invhouse", "parallelogram", "record", "Mrecord"]
            )

            node_style = st.selectbox(
                "Node Style",
                ["filled", "dashed", "dotted", "solid", "filled,rounded", "dashed,filled", "dotted,filled", "bold", "invis"],
                index=0
            )

            node_fillcolor = st.color_picker("Node Fill Color", "#D0E8FF")
            node_color = st.color_picker("Node Border Color", "#000000")

            # Advanced settings
            with st.expander("Advanced Settings"):
                peripheries = st.number_input("Periphery Lines", min_value=1, max_value=10, value=1, help="Number of periphery lines around the node. 2 or more creates a double-line effect.")

                font_col1, font_col2 = st.columns(2)
                with font_col1:
                    fontname = st.selectbox("Font Name", 
                                            ["sans-serif", "serif", "Arial", "Helvetica", "Times-Roman", "Courier", "MS Gothic", "MS UI Gothic", "Meiryo"], 
                                            index=0)
                    fontcolor = st.color_picker("Font Color", "#000000")

                with font_col2:
                    fontsize = st.number_input("Font Size", min_value=8, max_value=72, value=14)
                    tooltip = st.text_input("Tooltip", placeholder="Text to display on mouseover")

            node_submit = st.form_submit_button("Add Node")

            if node_submit and node_id and node_label:
                # Check for existing node IDs
                existing_ids = [node['id'] for node in st.session_state['nodes']]
                if node_id in existing_ids:
                    st.error(f"Node ID '{node_id}' is already in use. Please choose a different ID.")
                else:
                    style_str = node_style

                    st.session_state['nodes'].append({
                        'id': node_id,
                        'label': node_label,
                        'shape': node_shape,
                        'style': style_str,
                        'fillcolor': node_fillcolor,
                        'color': node_color,
                        'peripheries': peripheries,
                        'fontname': fontname,
                        'fontsize': fontsize,
                        'fontcolor': fontcolor,
                        'tooltip': tooltip
                    })
                    st.success(f"Added node '{node_id}'.")

    # Edge input form
    with edge_col:
        st.subheader("Edge Definition")

        with st.form(key="edge_form"):
            dir_option = st.selectbox(
                "Edge Direction",
                ["forward (directed)", "back (directed reverse)", "both (bidirectional)", "none (undirected)"],
                index=0,
                help="Forward draws an arrow from source to target, back draws from target to source, both draws bidirectional arrows, none creates an undirected edge."
            )

            # Get list of node IDs
            node_ids = [node['id'] for node in st.session_state['nodes']]

            if not node_ids:
                st.warning("Please add nodes first.")
                source_node = ""
                target_node = ""
                source_select = st.selectbox("Source Node", ["Please add nodes first"])
                target_select = st.selectbox("Target Node", ["Please add nodes first"])
            else:
                source_select = st.selectbox("Source Node", node_ids)
                target_select = st.selectbox("Target Node", node_ids)
                source_node = source_select
                target_node = target_select

            edge_label = st.text_input("Edge Label", placeholder="Example: data flow, call")

            edge_style = st.selectbox(
                "Edge Basic Style",
                ["solid (solid line)", "dashed (dashed line)", "dotted (dotted line)", "bold (thick line)"]
            )

            edge_color = st.color_picker("Edge Color", "#666666")

            # Arrow shape selection (adjusted automatically based on direction)
            if dir_option.startswith("none"):
                arrow_shape = "none"
            else:
                arrow_shape = st.selectbox(
                    "Arrow Shape",
                    ["normal", "vee", "tee", "dot", "diamond", "box", "crow", "inv", "invdot", "odot", "open", "halfopen", "none"]
                )

            # Advanced settings
            with st.expander("Advanced Settings"):
                font_col1, font_col2 = st.columns(2)
                with font_col1:
                    edge_fontname = st.selectbox("Font Name", 
                                            ["sans-serif", "serif", "Arial", "Helvetica", "Times-Roman", "Courier", "MS Gothic", "MS UI Gothic", "Meiryo"], 
                                            index=0,
                                            key="edge_fontname")
                    edge_fontcolor = st.color_picker("Font Color", "#000000", key="edge_fontcolor")

                with font_col2:
                    edge_fontsize = st.number_input("Font Size", min_value=8, max_value=72, value=10, key="edge_fontsize")
                    edge_tooltip = st.text_input("Tooltip", placeholder="Text to display on mouseover", key="edge_tooltip")

                edge_penwidth = st.slider("Line Width", min_value=0.5, max_value=10.0, value=1.0, step=0.5)
                edge_weight = st.slider("Weight", min_value=0.1, max_value=10.0, value=1.0, step=0.1, 
                                       help="Higher values indicate more important connections")
                edge_constraint = st.checkbox("Maintain Hierarchy", value=True, 
                                           help="When checked, maintains hierarchical structure (top to bottom, left to right)")
                edge_minlen = st.number_input("Minimum Length", min_value=1, max_value=10, value=1, 
                                           help="Specifies the minimum edge length")

            edge_submit = st.form_submit_button("Add Edge")

            if edge_submit and source_node and target_node:
                # Check for duplicates
                duplicate = False
                for edge in st.session_state['edges']:
                    if edge['source'] == source_node and edge['target'] == target_node:
                        duplicate = True
                        break

                if duplicate:
                    st.error(f"Edge '{source_node} -> {target_node}' already exists.")
                else:
                    dir_value = dir_option.split(" ")[0]  # Extract "forward" from "forward (directed)"

                    st.session_state['edges'].append({
                        'source': source_node,
                        'target': target_node,
                        'label': edge_label,
                        'style': edge_style,
                        'color': edge_color,
                        'dir': dir_value,
                        'arrow': arrow_shape,
                        'fontname': edge_fontname,
                        'fontsize': edge_fontsize,
                        'fontcolor': edge_fontcolor,
                        'tooltip': edge_tooltip,
                        'penwidth': edge_penwidth,
                        'weight': edge_weight,
                        'constraint': edge_constraint,
                        'minlen': edge_minlen
                    })
                    st.success(f"Added edge '{source_node} -> {target_node}'.")

    # List of added nodes and edges
    st.subheader("List of Added Nodes and Edges")

    nodes_tab, edges_tab = st.tabs(["Nodes List", "Edges List"])

    # Display nodes list
    with nodes_tab:
        if not st.session_state['nodes']:
            st.info("No nodes have been added yet.")
        else:
            display_nodes_table(st.session_state['nodes'])

            # Node deletion feature
            with st.form(key="node_delete_form"):
                delete_node = st.selectbox("Node to Delete", [node['id'] for node in st.session_state['nodes']])
                node_delete_submit = st.form_submit_button("Delete Selected Node")

                if node_delete_submit:
                    node_to_delete = delete_node
                    st.session_state['nodes'] = [node for node in st.session_state['nodes'] if node['id'] != node_to_delete]

                    # Delete related edges
                    st.session_state['edges'] = [
                        edge for edge in st.session_state['edges'] 
                        if edge['source'] != node_to_delete and edge['target'] != node_to_delete
                    ]
                    st.success(f"Deleted node '{node_to_delete}' and its related edges.")
                    st.rerun()

    # Display edges list
    with edges_tab:
        if not st.session_state['edges']:
            st.info("No edges have been added yet.")
        else:
            display_edges_table(st.session_state['edges'])

            # Edge deletion feature
            with st.form(key="edge_delete_form"):
                edge_options = [f"{edge['source']} -> {edge['target']}" for edge in st.session_state['edges']]
                delete_edge = st.selectbox("Edge to Delete", edge_options)
                edge_delete_submit = st.form_submit_button("Delete Selected Edge")

                if edge_delete_submit:
                    source, target = delete_edge.split(" -> ")
                    st.session_state['edges'] = [
                        edge for edge in st.session_state['edges'] 
                        if not (edge['source'] == source and edge['target'] == target)
                    ]
                    st.success(f"Deleted edge '{delete_edge}'.")
                    st.rerun()

    # Clear all data button
    if st.button("Clear All"):
        st.session_state['nodes'] = []
        st.session_state['edges'] = []
        st.success("Cleared all nodes and edges.")
        st.rerun()

# Tab 2: Templates
elif selected_tab == "Templates":
    st.subheader("Select from Templates")
    template_type = st.selectbox(
        "Template",
        ["Simple System Architecture", "Microservice Architecture", "Data Pipeline", 
         "Network Configuration", "Cloud Architecture"],
        key="template_select",
        on_change=reset_template_preview
    )
    st.info(f"Selected template: {template_type}")

    # Style customization
    st.subheader("Template Style Customization")
    st.caption("Customize the appearance of the template")

    template_style_col1, template_style_col2 = st.columns(2)

    with template_style_col1:
        template_node_shape = st.selectbox(
            "Node Shape",
            ["box", "ellipse", "circle", "diamond", "plaintext", "triangle", "hexagon", "cylinder"],
            key="node_shape_select",
            on_change=reset_template_preview
        )

        template_node_color = st.color_picker(
            "Node Color",
            "#D0E8FF",
            key="node_color_picker",
            on_change=reset_template_preview
        )

    with template_style_col2:
        template_edge_style = st.selectbox(
            "Edge Style",
            ["solid", "dashed", "thick", "arrow"],
            key="edge_style_select",
            on_change=reset_template_preview
        )

        template_edge_color = st.color_picker(
            "Edge Color",
            "#666666",
            key="edge_color_picker",
            on_change=reset_template_preview
        )

    # Template preview generation
    if st.button("Preview Template"):
        st.session_state['template_previewed'] = True

    # Preview display (only shown when button is pressed)
    if st.session_state['template_previewed']:
        try:
            # Create temporary graph object
            preview_graph = graphviz.Digraph()

            # Set graph direction
            if graph_direction == "Left to Right (LR)":
                preview_graph.attr(rankdir="LR")
            else:
                preview_graph.attr(rankdir="TB")

            # Common style settings for templates
            node_attrs = {
                "shape": template_node_shape,
                "style": "filled",
                "fillcolor": template_node_color
            }

            edge_attrs = {"color": template_edge_color}

            if template_edge_style == "dashed":
                edge_attrs["style"] = "dashed"
            elif template_edge_style == "thick":
                edge_attrs["penwidth"] = "2"
            elif template_edge_style == "arrow":
                edge_attrs["arrowhead"] = "normal"

            # Temporary template nodes and edges
            template_nodes = []
            template_edges = []

            # Define nodes and edges for each template
            if template_type == "Simple System Architecture":
                template_nodes = [
                    {"id": "client", "label": "Client", "attrs": node_attrs},
                    {"id": "server", "label": "Server", "attrs": node_attrs},
                    {"id": "db", "label": "Database", "attrs": node_attrs}
                ]
                template_edges = [
                    {"source": "client", "target": "server", "label": "Request", "attrs": edge_attrs},
                    {"source": "server", "target": "db", "label": "Query", "attrs": edge_attrs},
                    {"source": "db", "target": "server", "label": "Result", "attrs": edge_attrs},
                    {"source": "server", "target": "client", "label": "Response", "attrs": edge_attrs}
                ]

            elif template_type == "Microservice Architecture":
                template_nodes = [
                    {"id": "api", "label": "API Gateway", "attrs": node_attrs},
                    {"id": "auth", "label": "Auth Service", "attrs": node_attrs},
                    {"id": "user", "label": "User Service", "attrs": node_attrs},
                    {"id": "order", "label": "Order Service", "attrs": node_attrs},
                    {"id": "payment", "label": "Payment Service", "attrs": node_attrs},
                    {"id": "db_user", "label": "User DB", "attrs": node_attrs},
                    {"id": "db_order", "label": "Order DB", "attrs": node_attrs}
                ]
                template_edges = [
                    {"source": "api", "target": "auth", "label": "Authenticate", "attrs": edge_attrs},
                    {"source": "api", "target": "user", "label": "User Management", "attrs": edge_attrs},
                    {"source": "api", "target": "order", "label": "Order Processing", "attrs": edge_attrs},
                    {"source": "order", "target": "payment", "label": "Payment Processing", "attrs": edge_attrs},
                    {"source": "user", "target": "db_user", "label": "Data Storage", "attrs": edge_attrs},
                    {"source": "order", "target": "db_order", "label": "Data Storage", "attrs": edge_attrs}
                ]

            elif template_type == "Data Pipeline":
                template_nodes = [
                    {"id": "source", "label": "Data Source", "attrs": node_attrs},
                    {"id": "ingest", "label": "Ingestion", "attrs": node_attrs},
                    {"id": "process", "label": "Processing", "attrs": node_attrs},
                    {"id": "analyze", "label": "Analysis", "attrs": node_attrs},
                    {"id": "store", "label": "Storage", "attrs": node_attrs},
                    {"id": "visual", "label": "Visualization", "attrs": node_attrs}
                ]
                template_edges = [
                    {"source": "source", "target": "ingest", "label": "Extract", "attrs": edge_attrs},
                    {"source": "ingest", "target": "process", "label": "Pre-process", "attrs": edge_attrs},
                    {"source": "process", "target": "analyze", "label": "Analyze", "attrs": edge_attrs},
                    {"source": "analyze", "target": "store", "label": "Store", "attrs": edge_attrs},
                    {"source": "store", "target": "visual", "label": "Visualize", "attrs": edge_attrs}
                ]

            elif template_type == "Network Configuration":
                template_nodes = [
                    {"id": "router", "label": "Router", "attrs": node_attrs},
                    {"id": "switch1", "label": "Switch 1", "attrs": node_attrs},
                    {"id": "switch2", "label": "Switch 2", "attrs": node_attrs},
                    {"id": "server1", "label": "Server 1", "attrs": node_attrs},
                    {"id": "server2", "label": "Server 2", "attrs": node_attrs},
                    {"id": "client1", "label": "Client 1", "attrs": node_attrs},
                    {"id": "client2", "label": "Client 2", "attrs": node_attrs}
                ]
                template_edges = [
                    {"source": "router", "target": "switch1", "label": "Connection", "attrs": edge_attrs},
                    {"source": "router", "target": "switch2", "label": "Connection", "attrs": edge_attrs},
                    {"source": "switch1", "target": "server1", "label": "Connection", "attrs": edge_attrs},
                    {"source": "switch1", "target": "server2", "label": "Connection", "attrs": edge_attrs},
                    {"source": "switch2", "target": "client1", "label": "Connection", "attrs": edge_attrs},
                    {"source": "switch2", "target": "client2", "label": "Connection", "attrs": edge_attrs}
                ]

            elif template_type == "Cloud Architecture":
                template_nodes = [
                    {"id": "lb", "label": "Load Balancer", "attrs": node_attrs},
                    {"id": "web1", "label": "Web Server 1", "attrs": node_attrs},
                    {"id": "web2", "label": "Web Server 2", "attrs": node_attrs},
                    {"id": "app1", "label": "App Server 1", "attrs": node_attrs},
                    {"id": "app2", "label": "App Server 2", "attrs": node_attrs},
                    {"id": "db_master", "label": "DB Master", "attrs": node_attrs},
                    {"id": "db_slave", "label": "DB Slave", "attrs": node_attrs}
                ]
                template_edges = [
                    {"source": "lb", "target": "web1", "label": "Forward", "attrs": edge_attrs},
                    {"source": "lb", "target": "web2", "label": "Forward", "attrs": edge_attrs},
                    {"source": "web1", "target": "app1", "label": "Request", "attrs": edge_attrs},
                    {"source": "web2", "target": "app2", "label": "Request", "attrs": edge_attrs},
                    {"source": "app1", "target": "db_master", "label": "Query", "attrs": edge_attrs},
                    {"source": "app2", "target": "db_master", "label": "Query", "attrs": edge_attrs},
                    {"source": "db_master", "target": "db_slave", "label": "Replication", "attrs": edge_attrs}
                ]

            # Add nodes and edges to graph
            for node in template_nodes:
                preview_graph.node(node["id"], label=node["label"], **node["attrs"])

            for edge in template_edges:
                preview_graph.edge(edge["source"], edge["target"], label=edge.get("label", ""), **edge["attrs"])

            # Preview display
            st.subheader("Template Preview")
            st.graphviz_chart(preview_graph)

            # Display template node and edge information in table format
            st.subheader("Template Nodes and Edges Information")

            template_nodes_tab, template_edges_tab = st.tabs(["Template Nodes List", "Template Edges List"])

            with template_nodes_tab:
                display_nodes_table(template_nodes, is_template=True)

            with template_edges_tab:
                display_edges_table(template_edges, is_template=True)

            # Display generated DOT code
            st.subheader("Generated DOT Code")
            st.code(preview_graph.source, language="dot")

            # Option to apply template to custom input
            st.subheader("Apply Template to Custom Input")

            # Use form for batch processing of template application
            with st.form(key="apply_template_form"):
                st.write(f"Apply {template_type} template to nodes and edges input.")
                apply_submit = st.form_submit_button("Apply Template")

                # Prepare graph direction for template
                direction = "LR" if graph_direction == "Left to Right (LR)" else "TB"

                if apply_submit:
                    # Apply template to session
                    apply_template_to_session(template_nodes, template_edges, template_node_shape, template_node_color, template_edge_style, template_edge_color, template_type)
                    st.success(f"Applied {template_type} template to 'Nodes and Edges Input' tab.")

        except Exception as e:
            st.error(f"Error generating template preview: {str(e)}")

# Tab 3: Direct DOT Code Input
elif selected_tab == "Direct DOT Code Input":
    st.subheader("Direct DOT Code Input")

    # Load DOT code from session state, update after input
    dot_code = st.text_area(
        "DOT Code",
        st.session_state['dot_code'],
        height=300
    )

    # Update session state
    st.session_state['dot_code'] = dot_code

    # Preview button
    if st.button("Preview DOT Code"):
        try:
            # Render DOT code with Graphviz
            graph = graphviz.Source(dot_code)
            st.subheader("Preview")
            st.graphviz_chart(graph)
        except Exception as e:
            st.error(f"Error parsing DOT code: {str(e)}")

# Tab 4: AI Generation
elif selected_tab == "AI Generation":
    st.subheader("AI-Generated Graphviz")

    if not use_llm:
        st.warning("To use AI generation, please check 'Enable AI Generation' in the sidebar.")
        dot_code = ""
    else:
        # AI prompt input (load from session state, update after input)
        ai_prompt = st.text_area(
            "Enter a description of the diagram you want to create",
            st.session_state['ai_prompt'],
            height=150
        )

        # Update session state
        st.session_state['ai_prompt'] = ai_prompt

        direction_for_ai = "LR" if "Left to Right" in graph_direction else "TB"

        # Generation button
        if st.button("Generate Graphviz Code with AI"):
            with st.spinner("AI is generating Graphviz code..."):
                try:
                    # Create AI prompt
                    ai_system_prompt = f"""
You are an expert in Graphviz DOT language. Generate appropriate Graphviz DOT code based on the user's request.
Follow these guidelines to generate the DOT code:
1. Use digraph format
2. Set rankdir={direction_for_ai} for direction
3. Set appropriate shapes for each node (box, circle, ellipse, diamond, cylinder, etc.)
4. Use appropriate styles for edges (solid, dashed, arrows, etc.)
5. Use colors effectively
6. Use readable labels
7. Create a well-organized and visually appealing graph
8. Return only DOT code without any explanation

User's request: {ai_prompt}

DOT code:
"""

                    # Generate DOT code with AI
                    dot_code = CompleteText(lang_model, ai_system_prompt)

                    # Format response (remove extra backticks)
                    dot_code = dot_code.strip()
                    if dot_code.startswith("```dot"):
                        dot_code = dot_code[6:]
                    if dot_code.startswith("```"):
                        dot_code = dot_code[3:]
                    if dot_code.endswith("```"):
                        dot_code = dot_code[:-3]

                    # Display generated DOT code
                    st.subheader("Generated DOT Code")
                    st.code(dot_code, language="dot")

                    # Preview display
                    try:
                        graph = graphviz.Source(dot_code)
                        st.subheader("Preview")
                        st.graphviz_chart(graph)
                    except Exception as e:
                        st.error(f"Error in generated DOT code: {str(e)}")

                except Exception as e:
                    st.error(f"Error during AI generation: {str(e)}")
                    dot_code = ""

# Display current status in sidebar
st.sidebar.write(f"Current input method: {st.session_state['active_tab']}")

# Graph generation button
if selected_tab == "Nodes and Edges Input" and st.button("Generate Graph"):
    try:
        active_tab = st.session_state['active_tab']
        st.write(f"Graph generation: Using {active_tab} mode")

        # Create directed graph
        graph = graphviz.Digraph()

        # Set graph direction
        if graph_direction == "Left to Right (LR)":
            graph.attr(rankdir="LR")
        else:
            graph.attr(rankdir="TB")

        # Add nodes
        for node in st.session_state['nodes']:
            node_attrs = {
                "label": node['label'],
                "shape": node['shape'],
                "style": node['style'],
                "fillcolor": node['fillcolor'],
                "color": node['color'],
                "peripheries": str(node['peripheries']),
                "fontname": node['fontname'],
                "fontsize": str(node['fontsize']),
                "fontcolor": node['fontcolor'],
                "tooltip": node['tooltip']
            }
            graph.node(node['id'], **node_attrs)

        # Add edges
        for edge in st.session_state['edges']:
            edge_attrs = {}

            if edge['label']:
                edge_attrs["label"] = edge['label']

            # Style settings
            style = edge['style'].split(" ")[0]  # Extract "solid" from "solid (solid line)"
            if style == "dashed":
                edge_attrs["style"] = "dashed"
            elif style == "dotted":
                edge_attrs["style"] = "dotted"
            elif style == "bold":
                edge_attrs["penwidth"] = "2"

            # Color and arrow settings
            edge_attrs["color"] = edge['color']
            if edge.get('dir', 'forward') != "none":
                # Pass "none" through too; if arrowhead is omitted, Graphviz draws its default "normal" arrow
                edge_attrs["arrowhead"] = edge.get('arrow', 'normal')

            # Direction settings
            if 'dir' in edge and edge['dir'] != 'forward':
                edge_attrs["dir"] = edge['dir']

            # Font settings
            if 'fontname' in edge:
                edge_attrs["fontname"] = edge['fontname']
            if 'fontsize' in edge:
                edge_attrs["fontsize"] = str(edge['fontsize'])
            if 'fontcolor' in edge:
                edge_attrs["fontcolor"] = edge['fontcolor']

            # Additional settings
            if 'tooltip' in edge and edge['tooltip']:
                edge_attrs["tooltip"] = edge['tooltip']
            if 'penwidth' in edge:
                edge_attrs["penwidth"] = str(edge['penwidth'])
            if 'weight' in edge:
                edge_attrs["weight"] = str(edge['weight'])
            if 'constraint' in edge:
                edge_attrs["constraint"] = "true" if edge['constraint'] else "false"
            if 'minlen' in edge:
                edge_attrs["minlen"] = str(edge['minlen'])

            graph.edge(edge['source'], edge['target'], **edge_attrs)

        # Display graph
        st.header("Generated Diagram")
        st.graphviz_chart(graph)

        # Display DOT code
        st.subheader("Generated DOT Code")
        st.code(graph.source, language="dot")

        # Export guidance
        st.info("To export the diagram, copy the DOT code above and save it to a file.")

    except Exception as e:
        st.error(f"Error during graph generation: {str(e)}")

# Help information
with st.expander("Graphviz and Diagram Creation Help"):
    st.markdown("""
    ### What is Graphviz?
    Graphviz is an open-source graph visualization software. It can be used to create various types of diagrams including directed and undirected graphs.

    ### Basic Usage
    1. Select graph direction in the sidebar
    2. Choose an input method: "Nodes and Edges Input", "Templates", "Direct DOT Code Input", or "AI Generation"
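    3. Click the button for your input method to display the diagram (for example, "Generate Graph" in Nodes and Edges Input or "Preview DOT Code" in Direct DOT Code Input)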

    ### Nodes and Edges Form Input
    - Nodes: Fill in the form and click "Add Node" to add nodes to the list
    - Edges: Fill in the form and click "Add Edge" to add connections between nodes
    - View and delete added nodes and edges in the lists below

    ### Commonly Used Node Shapes
    - box: Rectangle (commonly used for process steps)
    - ellipse: Oval (default shape)
    - circle: Circle
    - diamond: Diamond shape
    - cylinder: Cylinder shape (often used for databases)
    - plaintext: Text only, no shape
    - polygon: Polygon
    - record: Record format (for multiple fields)

    ### Commonly Used Node Styles
    - filled: Filled with color
    - dashed: Dashed border
    - dotted: Dotted border
    - solid: Solid border
    - filled,rounded: Rounded corners with fill
    - dashed,filled: Dashed border with fill
    - dotted,filled: Dotted border with fill
    - bold: Bold border
    - invis: Invisible (the node is hidden, but its edges are still drawn)

    ### Commonly Used Edge Styles
    - solid: Solid line (default)
    - dashed: Dashed line
    - dotted: Dotted line
    - bold: Bold line (achieved with penwidth="2")

    ### Basic DOT Language Syntax

    ```
    digraph G {  // for directed graphs
        A -> B;  // edge from A to B
        B -> C;  // edge from B to C
    }
    ```

    ### Node and Edge Attribute Examples

    ```
    digraph G {
        // Node settings
        A [label="Start", shape=box, style=filled, fillcolor=lightblue];

        // Edge settings
        A -> B [label="Flow", style=dashed, color=red];
    }
    ```

    ### Using AI Generation
    1. Check "Enable AI Generation" in the sidebar
    2. Select the Cortex AI model to use
    3. Go to the "AI Generation" tab and enter a description of the diagram you want
    4. Click "Generate Graphviz Code with AI" to create a diagram based on your description
    """)

# Footer
st.caption("Created by Tsubasa Kanno")


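The AI tab strips Markdown fences from the model's reply with fixed prefix and suffix checks, which only work when the reply starts exactly with a fence tagged dot (or an untagged fence) and ends with one. Models sometimes add a sentence before the fence or use a different tag such as graphviz. A regex-based helper is one way to make this step more tolerant. The sketch below assumes the reply contains at most one relevant fenced block, and extract_dot_code is a hypothetical name, not part of the app.

```python
import re

FENCE = "`" * 3  # a Markdown code fence, built here to keep the source readable

# Opening fence with an optional language tag (dot, graphviz, ...), a newline,
# then the shortest body up to the next closing fence.
_FENCED_BLOCK = re.compile(FENCE + r"[ \t]*[\w-]*[ \t]*\n(.*?)" + FENCE, re.DOTALL)


def extract_dot_code(response: str) -> str:
    """Pull DOT source out of an LLM reply that may be wrapped in a Markdown fence."""
    match = _FENCED_BLOCK.search(response)
    if match:
        # Use the body of the first fenced block, ignoring any surrounding prose
        return match.group(1).strip()
    text = response.strip()
    if text.startswith(FENCE) and "\n" in text:
        # Opening fence but no closing one (e.g. a truncated reply): drop the first line
        return text.split("\n", 1)[1].strip()
    # No fence at all: assume the whole reply is DOT
    return text
```

Under that assumption, the three if statements after the CompleteText call could be replaced with a single line: dot_code = extract_dot_code(CompleteText(lang_model, ai_system_prompt)).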

Usage Examples

The application offers four input methods:

  1. Node and Edge Input: Add nodes and edges interactively from the UI
  2. Templates: Select and customize existing templates
  3. Direct DOT Code Input: Directly edit Graphviz syntax
  4. AI Generation: Automatically generate flow diagrams from natural language descriptions

With the AI generation feature in particular, you can generate diagrams by simply entering prompts like this:



Please create a system configuration diagram for a web application.
Users access through browsers, through a load balancer to multiple web servers.
Web servers communicate with application servers, and application servers store data in database servers.
A cache server is also used.



Graph generated by Cortex AI
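
Because the model replies in free text, the AI tab wraps rendering in try/except and reports "Error in generated DOT code" when Graphviz rejects the output. A cheap structural check before rendering could give a more specific message for common failures, such as a reply that is prose rather than code, or one cut off before its closing brace. The sketch below is a heuristic, not a DOT parser: looks_like_dot is a hypothetical helper, it ignores DOT comments and HTML-like labels, and Graphviz remains the real validator.

```python
import re

# Optional "strict", then "graph" or "digraph", an optional ID, then "{".
# DOT keywords are case-insensitive, hence IGNORECASE.
_HEADER_RE = re.compile(r"^\s*(strict\s+)?(di)?graph\b[^{]*\{", re.IGNORECASE)


def looks_like_dot(text: str) -> bool:
    """Heuristic: a DOT header is present and braces balance outside quoted strings."""
    if not _HEADER_RE.match(text):
        return False
    depth = 0
    in_string = False
    escaped = False
    for ch in text:
        if in_string:
            # Inside "...": only track escapes and the closing quote
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth < 0:
                return False  # a closing brace with no matching opener
    return depth == 0 and not in_string
```

If the check fails, the AI tab could show st.warning with a hint to regenerate instead of passing the text to graphviz.Source.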

Conclusion

By combining Graphviz, which is available in Streamlit in Snowflake, with Cortex AI, I was able to build an application that lets users create high-quality flow diagrams and architecture diagrams without specialized knowledge such as the DOT language.

Multimodal generative AI can now read and generate diagrams, but the results still usually need adjustment before they can be used. Having the AI produce code for a diagram tool like Graphviz, rather than an image, should help absorb this variability, since the generated code can be reviewed, edited, and re-rendered. I encourage everyone to try this application and explore new ways to use generative AI.
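
One way to act on this idea is to have the app check the generated code and ask again when it is unusable, instead of showing a broken diagram. The loop below is a hedged sketch of that approach and not part of the app: generate stands in for the Cortex call (in the app, something wrapping CompleteText(lang_model, ai_system_prompt)), is_valid could be a structural check or a trial render with Graphviz, and generate_until_valid is a hypothetical name.

```python
from typing import Callable, Optional


def generate_until_valid(
    generate: Callable[[str], str],
    prompt: str,
    is_valid: Callable[[str], bool],
    max_attempts: int = 3,
) -> Optional[str]:
    """Call generate(prompt) up to max_attempts times.

    Returns the first reply (stripped) that is_valid accepts,
    or None if every attempt is rejected.
    """
    for _ in range(max_attempts):
        candidate = generate(prompt).strip()
        if is_valid(candidate):
            return candidate
    return None


# Example with a fake generator standing in for the LLM call:
# the first reply is prose, the second is usable DOT.
fake_replies = iter(["Sorry, here is a description instead.", "digraph { A -> B }"])
result = generate_until_valid(
    generate=lambda _prompt: next(fake_replies),
    prompt="Two boxes, A pointing to B",
    is_valid=lambda text: text.startswith("digraph") and text.endswith("}"),
)
print(result)
```

Retries only help when the model's replies actually vary between calls; if the same prompt keeps yielding the same broken output, adjusting the prompt is the better fix.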

Promotion

Snowflake What's New Updates on X

I post Snowflake What's New updates on X. I'd be happy if you followed along:

English Version

Snowflake What's New Bot (English Version)

Japanese Version

Snowflake's What's New Bot (Japanese Version)

Change History

(2025/03/20) Initial post

Original Japanese Article

https://zenn.dev/tsubasa_tech/articles/5238844010c097
