DEV Community

Armaan Khan
Armaan Khan

Posted on • Edited on

New chart builder

from typing import Dict, List, Union

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from plotly.subplots import make_subplots

class ChartBuilder:
    """Builds charts and KPIs with error handling and fallback to tables"""

    @staticmethod
    def create_kpi_card(
        value, title: str, format_type: str = "number", color: str = "#1f77b4"
    ):
        """Render a single KPI as a Streamlit metric.

        Args:
            value: Raw value to display. Missing values (``None`` / NaN) are
                shown as "N/A"; anything else is converted with ``float``
                before formatting.
            title: Label shown on the metric.
            format_type: "currency" (dollar sign, thousands separators, two
                decimals) or "percentage" (one decimal followed by "%").
                Any other value formats as a whole number with thousands
                separators.
            color: Accepted for interface compatibility; not used by this
                method.

        Any exception raised while formatting or rendering is reported via
        ``st.error`` and followed by a placeholder metric whose value is
        "Error".
        """
        try:
            is_missing = pd.isna(value) or value is None
            if is_missing:
                display_text = "N/A"
            elif format_type == "currency":
                display_text = f"${float(value):,.2f}"
            elif format_type == "percentage":
                display_text = f"{float(value):.1f}%"
            else:
                display_text = f"{float(value):,.0f}"

            # Streamlit's built-in metric widget does the actual rendering.
            st.metric(label=title, value=display_text)

        except Exception as exc:
            st.error(f"Error creating KPI card: {str(exc)}")
            st.metric(label=title, value="Error")

    @staticmethod
    def create_metric_cards(
        data: Union[pd.DataFrame, Dict],
        config: Dict,
        tooltip_fields: Union[List[str], None] = None,
    ) -> None:
        """
        Render one Streamlit metric card per configured column, side by side.

        Args:
            data: DataFrame (its first row supplies the values) or dict keyed
                by column name. A missing column/key, or an empty DataFrame,
                yields the value 0; a missing DataFrame column additionally
                shows a warning inside that card.
            config: Settings read with ``config.get``:
                - "columns": column names / dict keys to display (required).
                - "titles": label per card; cards without an entry use the
                  column name.
                - "formats": "currency", "percentage", or anything else for
                  a whole number; cards without an entry use "number".
                - "colors": may be present but is not applied by this method.
            tooltip_fields: Optional help texts, matched to cards by position.

        Errors for a single card are shown in that card (value "Error" with
        the message as help text); errors outside the per-card loop are
        reported with ``st.error``.
        """
        try:
            columns = config.get("columns", [])
            if not columns:
                st.error("No columns specified for metric cards")
                return

            # Per-card settings are looked up by position with a default
            # rather than zipped together: zip() stops at the shortest list,
            # so a partially specified "titles"/"formats"/"colors" list used
            # to silently drop the trailing cards.
            titles = config.get("titles") or []
            formats = config.get("formats") or []
            tooltips = tooltip_fields or []

            # Create columns for layout
            cols = st.columns(len(columns))

            for i, col in enumerate(columns):
                title = titles[i] if i < len(titles) else col
                fmt = formats[i] if i < len(formats) else "number"
                tooltip = tooltips[i] if i < len(tooltips) else None

                with cols[i]:
                    try:
                        # Get value from DataFrame or dict
                        if isinstance(data, pd.DataFrame):
                            if col in data.columns:
                                value = data[col].iloc[0] if len(data) > 0 else 0
                            else:
                                value = 0
                                st.warning(f"Column '{col}' not found")
                        else:
                            value = data.get(col, 0)

                        # Format value
                        if pd.isna(value) or value is None:
                            formatted_value = "N/A"
                        elif fmt == "currency":
                            formatted_value = f"${float(value):,.2f}"
                        elif fmt == "percentage":
                            formatted_value = f"{float(value):.1f}%"
                        else:
                            formatted_value = f"{float(value):,.0f}"

                        # Create metric with tooltip
                        st.metric(label=title, value=formatted_value, help=tooltip)

                    except Exception as e:
                        # Keep the remaining cards rendering; surface the
                        # failure inside the affected card only.
                        st.metric(label=title, value="Error", help=f"Error: {str(e)}")

        except Exception as e:
            st.error(f"Error creating metric cards: {str(e)}")

    @staticmethod
    def create_recommendation_table(
        df: pd.DataFrame,
        config: Dict,
        tooltip_fields: Union[List[str], None] = None,
    ) -> None:
        """
        Render recommendations as expandable cards plus optional summaries.

        Each row becomes an expander titled with the cause, containing a
        priority badge and the recommendation text. Numeric columns (first
        four) are summarised by their mean, and a "status" column, when
        present, is shown as a pie chart.

        Args:
            df: One row per recommendation. Optional columns read per row:
                "priority" ("high" / "low", case-insensitive; anything else,
                including missing values, is shown as medium), "impact" and
                "timeline". An optional "status" column drives the pie chart.
            config: Settings read with ``config.get``:
                - "title": section heading (default "Recommendations").
                - "x_axis": name of the cause column (default "cause").
                - "y_axis": name of the recommendation column
                  (default "recommendation").
            tooltip_fields: Optional extra notes, matched to rows by position
                and shown as a caption inside the row's expander.

        If the required columns are missing, or any exception is raised, an
        error is shown and rendering falls back to
        ``ChartBuilder._create_table(df, title)``.
        """
        # Bound before the try block so the fallback in the except handler
        # can always reference it, even when the failure happens early.
        title = "Recommendations"
        try:
            title = config.get("title", title)

            if df.empty:
                st.warning("No recommendation data available")
                return

            cause_col = config.get("x_axis", "cause")
            recommendation_col = config.get("y_axis", "recommendation")

            # Validate columns exist
            if cause_col not in df.columns or recommendation_col not in df.columns:
                st.error(f"Required columns '{cause_col}' or '{recommendation_col}' not found")
                ChartBuilder._create_table(df, title)
                return

            st.subheader(title)

            # iterrows() yields index *labels*; tooltips are matched by row
            # position, so track the position separately.
            for position, (_, row) in enumerate(df.iterrows()):
                cause = row[cause_col]
                recommendation = row[recommendation_col]

                # Create expandable recommendation card
                with st.expander(f"🔍 {cause}", expanded=False):
                    col1, col2 = st.columns([1, 3])

                    with col1:
                        # Status/Priority indicator. NaN or other non-string
                        # priorities are treated as medium instead of raising.
                        priority = row.get("priority", "Medium")
                        level = priority.lower() if isinstance(priority, str) else "medium"
                        if level == "high":
                            st.error("🔴 High Priority")
                        elif level == "low":
                            st.success("🟢 Low Priority")
                        else:
                            st.warning("🟡 Medium Priority")

                    with col2:
                        # Recommendation details
                        st.markdown("**Recommendation:**")
                        st.info(recommendation)

                        # Additional details if available
                        if "impact" in row and pd.notna(row["impact"]):
                            st.markdown(f"**Expected Impact:** {row['impact']}")

                        if "timeline" in row and pd.notna(row["timeline"]):
                            st.markdown(f"**Timeline:** {row['timeline']}")

                        # Tooltip information. st.caption shows the text
                        # itself; st.help would document the str object.
                        if tooltip_fields and position < len(tooltip_fields):
                            st.caption(tooltip_fields[position])

            # Summary metrics if numeric data available (limit to 4 metrics)
            numeric_cols = df.select_dtypes(include=[np.number]).columns[:4]
            if len(numeric_cols) > 0:
                st.markdown("---")
                st.markdown("**Summary Metrics:**")

                summary_cols = st.columns(len(numeric_cols))
                for slot, col in zip(summary_cols, numeric_cols):
                    # str() so non-string column labels do not break .replace
                    readable = str(col).replace("_", " ")
                    with slot:
                        st.metric(
                            label=readable.title(),
                            value=f"{df[col].mean():.2f}",
                            help=f"Average {readable.lower()}",
                        )

            # Action items visualization
            if "status" in df.columns:
                st.markdown("---")
                st.markdown("**Action Status Overview:**")

                status_df = (
                    df["status"]
                    .value_counts()
                    .rename_axis("status")
                    .reset_index(name="count")
                )
                fig = px.pie(
                    status_df,
                    names="status",
                    values="count",
                    # color_discrete_map is only applied when `color` is set.
                    color="status",
                    title="Recommendation Status Distribution",
                    color_discrete_map={
                        "Completed": "#00CC96",
                        "In Progress": "#FFA15A",
                        "Pending": "#EF553B",
                        "Not Started": "#636EFA",
                    },
                )
                st.plotly_chart(fig, use_container_width=True)

        except Exception as e:
            st.error(f"Error creating recommendation table: {str(e)}")
            ChartBuilder._create_table(df, title)

    @staticmethod
    def create_chart(df: pd.DataFrame, config: Dict) -> None:
        """
        Creates charts based on config with fallback to table on errors
        """
        if df.empty:
            st.warning("No data available for chart")
            return

        chart_type = config.get("chart_type", "table").lower()
        title = config.get("title", "Chart")
        tooltip_fields = config.get("tooltip_fields", [])

        try:
            if chart_type == "kpi":
                # Handle KPI from dataframe
                value_col = config.get("value_column")
                if value_col and value_col in df.columns:
                    value = df[value_col].iloc[0] if len(df) > 0 else 0
                    ChartBuilder.create_kpi_card(
                        value,
                        title,
                        config.get("format", "number"),
                        config.get("color", "#1f77b4"),
                    )
                else:
                    st.error("KPI value column not found")
                return

            elif chart_type == "metric_cards":
                # Handle multiple metric cards
                ChartBuilder.create_metric_cards(df, config, tooltip_fields)
                return

            elif chart_type == "recommendation_table":
                # Handle recommendation table
                ChartBuilder.create_recommendation_table(df, config, tooltip_fields)
                return

            elif chart_type == "bar":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for bar chart")

                fig = px.bar(
                    df,
                    x=x_col,
                    y=y_col,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )

                # Add tooltips if provided
                if tooltip_fields:
                    fig.update_traces(
                        hovertemplate="<br>".join([f"{field}: %{{customdata[{i}]}}" for i, field in enumerate(tooltip_fields)]) + "<extra></extra>",
                        customdata=df[tooltip_fields].values if all(field in df.columns for field in tooltip_fields) else None
                    )

                fig.update_layout(showlegend=False)
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "line":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for line chart")

                fig = px.line(
                    df,
                    x=x_col,
                    y=y_col,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )

                # Add tooltips if provided
                if tooltip_fields:
                    fig.update_traces(
                        hovertemplate="<br>".join([f"{field}: %{{customdata[{i}]}}" for i, field in enumerate(tooltip_fields)]) + "<extra></extra>",
                        customdata=df[tooltip_fields].values if all(field in df.columns for field in tooltip_fields) else None
                    )

                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "pie":
                y_cols = config.get("y_axis")  # This is a list of metric columns
                stack_field = config.get("stack_field", "METRIC_TYPE")
                title = config.get("title", "Pie Chart")

                if not y_cols or not all(col in df.columns for col in y_cols):
                    raise ValueError("Required columns not found for pie chart")

                # Melt and aggregate
                df_melted = df.melt(
                    id_vars=["USER_NAME"] if "USER_NAME" in df.columns else ["user"],
                    value_vars=y_cols,
                    var_name=stack_field,
                    value_name="Value"
                )
                df_summary = df_melted.groupby(stack_field, as_index=False)["Value"].sum()
                df_summary.rename(columns={"Value": "TOTAL_COUNT"}, inplace=True)

                # Render pie chart
                fig = px.pie(df_summary, names=stack_field, values="TOTAL_COUNT", title=title)
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "scatter":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for scatter chart")

                fig = px.scatter(
                    df,
                    x=x_col,
                    y=y_col,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )

                # Add tooltips if provided
                if tooltip_fields:
                    fig.update_traces(
                        hovertemplate="<br>".join([f"{field}: %{{customdata[{i}]}}" for i, field in enumerate(tooltip_fields)]) + "<extra></extra>",
                        customdata=df[tooltip_fields].values if all(field in df.columns for field in tooltip_fields) else None
                    )

                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "area":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for area chart")

                fig = px.area(
                    df,
                    x=x_col,
                    y=y_col,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "histogram":
                x_col = config.get("x_axis")

                if not x_col or x_col not in df.columns:
                    raise ValueError("Required column not found for histogram")

                fig = px.histogram(
                    df,
                    x=x_col,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "box":
                y_col = config.get("y_axis")
                x_col = config.get("x_axis")  # optional for grouping

                if not y_col or y_col not in df.columns:
                    raise ValueError("Required column not found for box plot")

                fig = px.box(
                    df,
                    y=y_col,
                    x=x_col if x_col and x_col in df.columns else None,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "violin":
                y_col = config.get("y_axis")
                x_col = config.get("x_axis")  # optional for grouping

                if not y_col or y_col not in df.columns:
                    raise ValueError("Required column not found for violin plot")

                fig = px.violin(
                    df,
                    y=y_col,
                    x=x_col if x_col and x_col in df.columns else None,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "heatmap":
                # For heatmap, we'll use all numeric columns or specified columns
                numeric_cols = df.select_dtypes(include=[np.number]).columns
                if len(numeric_cols) < 2:
                    raise ValueError("Not enough numeric columns for heatmap")

                correlation_matrix = df[numeric_cols].corr()
                fig = px.imshow(
                    correlation_matrix, title=title, color_continuous_scale="RdBu_r"
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "donut":
                names_col = config.get("names_column", config.get("x_axis"))
                values_col = config.get("values_column", config.get("y_axis"))

                if (
                    not names_col
                    or not values_col
                    or names_col not in df.columns
                    or values_col not in df.columns
                ):
                    raise ValueError("Required columns not found for donut chart")

                fig = px.pie(
                    df,
                    names=names_col,
                    values=values_col,
                    title=title,
                    hole=0.4,  # This makes it a donut chart
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "funnel":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for funnel chart")

                fig = px.funnel(
                    df,
                    x=y_col,
                    y=x_col,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "gauge":
                value_col = config.get("value_column", config.get("y_axis"))
                if not value_col or value_col not in df.columns:
                    raise ValueError("Required column not found for gauge chart")

                value = df[value_col].iloc[0] if len(df) > 0 else 0
                max_val = config.get("max_value", df[value_col].max())

                fig = go.Figure(
                    go.Indicator(
                        mode="gauge+number",
                        value=value,
                        domain={"x": [0, 1], "y": [0, 1]},
                        title={"text": title},
                        gauge={
                            "axis": {"range": [None, max_val]},
                            "bar": {"color": config.get("color", "#1f77b4")},
                            "steps": [
                                {"range": [0, max_val / 2], "color": "lightgray"},
                                {"range": [max_val / 2, max_val], "color": "gray"},
                            ],
                            "threshold": {
                                "line": {"color": "red", "width": 4},
                                "thickness": 0.75,
                                "value": max_val * 0.9,
                            },
                        },
                    )
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "treemap":
                path_col = config.get("path_column", config.get("x_axis"))
                values_col = config.get("values_column", config.get("y_axis"))

                if (
                    not path_col
                    or not values_col
                    or path_col not in df.columns
                    or values_col not in df.columns
                ):
                    raise ValueError("Required columns not found for treemap")

                fig = px.treemap(df, path=[path_col], values=values_col, title=title)
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "sunburst":
                path_col = config.get("path_column", config.get("x_axis"))
                values_col = config.get("values_column", config.get("y_axis"))

                if (
                    not path_col
                    or not values_col
                    or path_col not in df.columns
                    or values_col not in df.columns
                ):
                    raise ValueError("Required columns not found for sunburst")

                fig = px.sunburst(df, path=[path_col], values=values_col, title=title)
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "waterfall":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for waterfall chart")

                fig = go.Figure(
                    go.Waterfall(
                        name=title,
                        orientation="v",
                        measure=["relative"] * len(df),
                        x=df[x_col],
                        textposition="outside",
                        text=[f"{v:,.0f}" for v in df[y_col]],
                        y=df[y_col],
                        connector={"line": {"color": "rgb(63, 63, 63)"}},
                    )
                )
                fig.update_layout(title=title, showlegend=False)
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "stacked_bar":
                x_col = config.get("x_axis")
                y_cols = config.get("y_axis")  # This is a list
                stack_field = config.get("stack_field", "Metric Type")
                title = config.get("title", "Stacked Bar Chart")

                # Validate columns
                if (
                    not x_col
                    or not y_cols
                    or x_col not in df.columns
                    or not all(col in df.columns for col in y_cols)
                ):
                    raise ValueError("Required columns not found for stacked bar chart")

                # Melt the DataFrame to long format
                df_melted = df.melt(
                    id_vars=[x_col],
                    value_vars=y_cols,
                    var_name=stack_field,
                    value_name="Value",
                )

                # Create the chart
                fig = px.bar(
                    df_melted,
                    x=x_col,
                    y="Value",
                    color=stack_field,
                    title=title,
                    barmode="stack",
                    color_discrete_map=config.get(
                        "colors", {}
                    ),  # Optional custom colors
                )

                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "grouped_bar":
                x_col = config.get("x_axis")
                color_col = config.get("color_column")
                y_col = config.get("y_axis")
                stack_field = config.get("stack_field", "Metric Type")

                # if (
                #     not x_col
                #     or not y_col
                #     or x_col not in df.columns
                #    or not all(col in df.columns for col in y_col)
                # ):
                #     raise ValueError("Required columns not found for grouped bar chart")

                df_melted = df.melt(
                    id_vars=[x_col],
                    value_vars=y_col,
                    var_name=stack_field,
                    value_name="Value",
                )

                fig = px.bar(
                    df_melted,
                    x=x_col,
                      y="Value",
                    color=stack_field,
                    title=title,
                    barmode="group",
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "horizontal_bar":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError(
                        "Required columns not found for horizontal bar chart"
                    )

                fig = px.bar(
                    df,
                    x=y_col,
                    y=x_col,
                    title=title,
                    orientation="h",
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "multi_line":
                x_col = config.get("x_axis")
                y_cols = config.get("y_columns", [])  # List of columns

                if not x_col or not y_cols or x_col not in df.columns:
                    raise ValueError("Required columns not found for multi-line chart")

                fig = go.Figure()
                colors = [
                    "#1f77b4",
                    "#ff7f0e",
                    "#2ca02c",
                    "#d62728",
                    "#9467bd",
                    "#8c564b",
                ]

                for i, y_col in enumerate(y_cols):
                    if y_col in df.columns:
                        fig.add_trace(
                            go.Scatter(
                                x=df[x_col],
                                y=df[y_col],
                                mode="lines+markers",
                                name=y_col,
                                line=dict(color=colors[i % len(colors)]),
                            )
                        )

                fig.update_layout(title=title, xaxis_title=x_col, yaxis_title="Values")
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "combo":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")
                line_col = config.get("line_column")

                if not all([x_col, y_col, line_col]) or not all(
                    [col in df.columns for col in [x_col, y_col, line_col]]
                ):
                    raise ValueError("Required columns not found for combo chart")

                fig = make_subplots(specs=[[{"secondary_y": True}]])

                # Add bar chart
                fig.add_trace(
                    go.Bar(x=df[x_col], y=df[y_col], name=y_col),
                    secondary_y=False,
                )

                # Add line chart
                fig.add_trace(
                    go.Scatter(
                        x=df[x_col], y=df[line_col], mode="lines+markers", name=line_col
                    ),
                    secondary_y=True,
                )

                fig.update_xaxes(title_text=x_col)
                fig.update_yaxes(title_text=y_col, secondary_y=False)
                fig.update_yaxes(title_text=line_col, secondary_y=True)
                fig.update_layout(title_text=title)

                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "bullet":
                value_col = config.get("value_column", config.get("y_axis"))
                target_col = config.get("target_column")

                if not value_col or value_col not in df.columns:
                    raise ValueError("Required column not found for bullet chart")

                value = df[value_col].iloc[0] if len(df) > 0 else 0
                target = (
                    df[target_col].iloc[0]
                    if target_col and target_col in df.columns
                    else value * 1.2
                )

                fig = go.Figure()

                fig.add_trace(
                    go.Indicator(
                        mode="number+gauge+delta",
                        value=value,
                        domain={"x": [0, 1], "y": [0, 1]},
                        title={"text": title},
                        delta={"reference": target},
                        gauge={
                            "shape": "bullet",
                            "axis": {"range": [None, target * 1.5]},
                            "threshold": {
                                "line": {"color": "red", "width": 2},
                                "thickness": 0.75,
                                "value": target,
                            },
                            "steps": [
                                {"range": [0, target * 0.5], "color": "lightgray"},
                                {"range": [target * 0.5, target], "color": "gray"},
                            ],
                            "bar": {"color": config.get("color", "#1f77b4")},
                        },
                    )
                )

                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "density_heatmap":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for density heatmap")

                fig = px.density_heatmap(
                    df,
                    x=x_col,
                    y=y_col,
                    title=title,
                    marginal_x="histogram",
                    marginal_y="histogram",
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "stacked_area":
                x_col = config.get("x_axis")
                color_col = config.get("color_column")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError(
                        "Required columns not found for stacked area chart"
                    )

                fig = px.area(
                    df,
                    x=x_col,
                    y=y_col,
                    color=color_col if color_col and color_col in df.columns else None,
                    title=title,
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "radar":
                # For radar charts, we need multiple numeric columns
                categories_col = config.get("categories_column", config.get("x_axis"))
                values_cols = config.get("values_columns", [])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)