DEV Community

Armaan Khan
Armaan Khan

Posted on

ne chart builder

from typing import Dict, List, Union

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from plotly.subplots import make_subplots

class ChartBuilder:
    """Builds charts and KPIs with error handling and fallback to tables"""

    @staticmethod
    def create_kpi_card(
        value, title: str, format_type: str = "number", color: str = "#1f77b4"
    ):
        """Render a single KPI as a Streamlit metric.

        Args:
            value: Raw value to display; ``None`` or NaN is shown as "N/A".
            title: Label shown on the metric.
            format_type: "currency" gives a ``$``-prefixed value with two
                decimals, "percentage" gives one decimal followed by ``%``,
                and anything else gives a thousands-separated whole number.
            color: Accepted for interface compatibility; this method does not
                read it.
        """
        try:
            if value is None or pd.isna(value):
                display_text = "N/A"
            elif format_type == "currency":
                display_text = f"${float(value):,.2f}"
            elif format_type == "percentage":
                display_text = f"{float(value):.1f}%"
            else:
                display_text = f"{float(value):,.0f}"

            # Streamlit's built-in metric widget does the actual rendering.
            st.metric(label=title, value=display_text)

        except Exception as exc:
            # Report the problem, then still render a placeholder metric so
            # the card's position in the layout is kept.
            st.error(f"Error creating KPI card: {exc!s}")
            st.metric(label=title, value="Error")

    @staticmethod
    def create_metric_cards(
        data: Union[pd.DataFrame, Dict],
        config: Dict,
        tooltip_fields: List[str] = None
    ) -> None:
        """
        Render one Streamlit metric per configured column, side by side.

        Args:
            data: A DataFrame (values are taken from its first row) or a dict
                keyed by column name. Missing entries are displayed as 0.
            config: Reads "columns" (required) plus the optional "titles" and
                "formats" lists, which are matched to "columns" by position.
                A format of "currency" or "percentage" selects that style;
                anything else gives a thousands-separated whole number.
                A "colors" entry is not used for rendering.
            tooltip_fields: Optional help texts, matched to metrics by
                position; metrics without a matching entry get no tooltip.
        """
        try:
            columns = list(config.get("columns") or [])
            if not columns:
                st.error("No columns specified for metric cards")
                return

            titles = list(config.get("titles") or [])
            formats = list(config.get("formats") or [])
            # zip() stops at the shortest input, so a short "titles" or
            # "formats" list would silently drop cards. Pad both to one entry
            # per column: the column name as title, "number" as format.
            titles += columns[len(titles):]
            formats += ["number"] * (len(columns) - len(formats))

            # One layout slot per metric.
            slots = st.columns(len(columns))

            for i, (col, title, fmt) in enumerate(zip(columns, titles, formats)):
                with slots[i]:
                    try:
                        # Get value from DataFrame or dict
                        if isinstance(data, pd.DataFrame):
                            if col in data.columns:
                                value = data[col].iloc[0] if len(data) > 0 else 0
                            else:
                                value = 0
                                st.warning(f"Column '{col}' not found")
                        else:
                            value = data.get(col, 0)

                        # Format value
                        if value is None or pd.isna(value):
                            formatted_value = "N/A"
                        elif fmt == "currency":
                            formatted_value = f"${float(value):,.2f}"
                        elif fmt == "percentage":
                            formatted_value = f"{float(value):.1f}%"
                        else:
                            formatted_value = f"{float(value):,.0f}"

                        has_tooltip = bool(tooltip_fields) and i < len(tooltip_fields)
                        st.metric(
                            label=title,
                            value=formatted_value,
                            help=tooltip_fields[i] if has_tooltip else None,
                        )

                    except Exception as e:
                        # A failure in one card must not take down the others;
                        # surface the error in that card's tooltip instead.
                        st.metric(label=title, value="Error", help=f"Error: {str(e)}")

        except Exception as e:
            st.error(f"Error creating metric cards: {str(e)}")

    @staticmethod
    def create_recommendation_table(
        df: pd.DataFrame,
        config: Dict,
        tooltip_fields: List[str] = None
    ) -> None:
        """
        Render recommendations as expandable cards with optional summaries.

        Each row becomes an expander titled with its cause, showing a priority
        badge and the recommendation text, plus "impact" and "timeline" values
        when the row has them. After the cards, up to four numeric columns are
        summarised by their mean, and a pie chart of the "status" column is
        drawn when that column exists. If the configured columns are missing,
        or anything raises, the DataFrame is handed to
        ``ChartBuilder._create_table`` instead.

        Args:
            df: DataFrame with one row per recommendation.
            config: Reads "title" (default "Recommendations"), "x_axis" (cause
                column, default "cause") and "y_axis" (recommendation column,
                default "recommendation").
            tooltip_fields: Optional notes, matched to rows by position and
                shown as a caption inside the row's card.
        """
        # Bound before the try block so the fallback in the except handler can
        # always use it, even when the failure happens before config is read.
        title = "Recommendations"
        try:
            if df is None or df.empty:
                st.warning("No recommendation data available")
                return

            title = config.get("title", title)
            cause_col = config.get("x_axis", "cause")
            recommendation_col = config.get("y_axis", "recommendation")

            # Validate columns exist
            if cause_col not in df.columns or recommendation_col not in df.columns:
                st.error(f"Required columns '{cause_col}' or '{recommendation_col}' not found")
                ChartBuilder._create_table(df, title)
                return

            st.subheader(title)

            # enumerate() supplies the row's position; the index label yielded
            # by iterrows() may be non-integer or non-sequential, so it cannot
            # be used to look up tooltip_fields.
            for position, (_, row) in enumerate(df.iterrows()):
                cause = row[cause_col]
                recommendation = row[recommendation_col]

                # Create expandable recommendation card
                with st.expander(f"🔍 {cause}", expanded=False):
                    col1, col2 = st.columns([1, 3])

                    with col1:
                        # Status/Priority indicator. Non-string values (NaN,
                        # numbers) fall through to the medium badge rather
                        # than raising on .lower().
                        priority = row.get('priority', 'Medium')
                        level = priority.lower() if isinstance(priority, str) else ""
                        if level == 'high':
                            st.error("🔴 High Priority")
                        elif level == 'low':
                            st.success("🟢 Low Priority")
                        else:
                            st.warning("🟡 Medium Priority")

                    with col2:
                        # Recommendation details
                        st.markdown("**Recommendation:**")
                        st.info(recommendation)

                        # Additional details if available
                        if 'impact' in row and pd.notna(row['impact']):
                            st.markdown(f"**Expected Impact:** {row['impact']}")

                        if 'timeline' in row and pd.notna(row['timeline']):
                            st.markdown(f"**Timeline:** {row['timeline']}")

                        # Tooltip text is shown as a caption; st.help() would
                        # render introspection of the str object instead.
                        if tooltip_fields and position < len(tooltip_fields):
                            st.caption(tooltip_fields[position])

            # Summary metrics if numeric data available
            numeric_cols = df.select_dtypes(include=[np.number]).columns[:4]  # Limit to 4 metrics
            if len(numeric_cols) > 0:
                st.markdown("---")
                st.markdown("**Summary Metrics:**")

                summary_cols = st.columns(len(numeric_cols))
                for i, col in enumerate(numeric_cols):
                    with summary_cols[i]:
                        # str() keeps non-string column labels from breaking
                        # the label formatting.
                        readable = str(col).replace('_', ' ')
                        avg_val = df[col].mean()
                        st.metric(
                            label=readable.title(),
                            value=f"{avg_val:.2f}",
                            help=f"Average {readable.lower()}"
                        )

            # Action items visualization
            if 'status' in df.columns:
                st.markdown("---")
                st.markdown("**Action Status Overview:**")

                status_counts = df['status'].value_counts()
                # NOTE(review): no `color=` argument is passed here; confirm
                # that color_discrete_map takes effect for px.pie without it.
                fig = px.pie(
                    values=status_counts.values,
                    names=status_counts.index,
                    title="Recommendation Status Distribution",
                    color_discrete_map={
                        'Completed': '#00CC96',
                        'In Progress': '#FFA15A',
                        'Pending': '#EF553B',
                        'Not Started': '#636EFA'
                    }
                )
                st.plotly_chart(fig, use_container_width=True)

        except Exception as e:
            st.error(f"Error creating recommendation table: {str(e)}")
            ChartBuilder._create_table(df, title)

    @staticmethod
    def create_chart(df: pd.DataFrame, config: Dict) -> None:
        """
        Creates charts based on config with fallback to table on errors
        """
        if df.empty:
            st.warning("No data available for chart")
            return

        chart_type = config.get("chart_type", "table").lower()
        title = config.get("title", "Chart")
        tooltip_fields = config.get("tooltip_fields", [])

        try:
            if chart_type == "kpi":
                # Handle KPI from dataframe
                value_col = config.get("value_column")
                if value_col and value_col in df.columns:
                    value = df[value_col].iloc[0] if len(df) > 0 else 0
                    ChartBuilder.create_kpi_card(
                        value,
                        title,
                        config.get("format", "number"),
                        config.get("color", "#1f77b4"),
                    )
                else:
                    st.error("KPI value column not found")
                return

            elif chart_type == "metric_cards":
                # Handle multiple metric cards
                ChartBuilder.create_metric_cards(df, config, tooltip_fields)
                return

            elif chart_type == "recommendation_table":
                # Handle recommendation table
                ChartBuilder.create_recommendation_table(df, config, tooltip_fields)
                return

            elif chart_type == "bar":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for bar chart")

                fig = px.bar(
                    df,
                    x=x_col,
                    y=y_col,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )

                # Add tooltips if provided
                if tooltip_fields:
                    fig.update_traces(
                        hovertemplate="<br>".join([f"{field}: %{{customdata[{i}]}}" for i, field in enumerate(tooltip_fields)]) + "<extra></extra>",
                        customdata=df[tooltip_fields].values if all(field in df.columns for field in tooltip_fields) else None
                    )

                fig.update_layout(showlegend=False)
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "line":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for line chart")

                fig = px.line(
                    df,
                    x=x_col,
                    y=y_col,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )

                # Add tooltips if provided
                if tooltip_fields:
                    fig.update_traces(
                        hovertemplate="<br>".join([f"{field}: %{{customdata[{i}]}}" for i, field in enumerate(tooltip_fields)]) + "<extra></extra>",
                        customdata=df[tooltip_fields].values if all(field in df.columns for field in tooltip_fields) else None
                    )

                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "pie":
                y_cols = config.get("y_axis")  # This is a list of metric columns
                stack_field = config.get("stack_field", "METRIC_TYPE")
                title = config.get("title", "Pie Chart")

                if not y_cols or not all(col in df.columns for col in y_cols):
                    raise ValueError("Required columns not found for pie chart")

                # Melt and aggregate
                df_melted = df.melt(
                    id_vars=["USER_NAME"] if "USER_NAME" in df.columns else ["user"],
                    value_vars=y_cols,
                    var_name=stack_field,
                    value_name="Value"
                )
                df_summary = df_melted.groupby(stack_field, as_index=False)["Value"].sum()
                df_summary.rename(columns={"Value": "TOTAL_COUNT"}, inplace=True)

                # Render pie chart
                fig = px.pie(df_summary, names=stack_field, values="TOTAL_COUNT", title=title)
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "scatter":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for scatter chart")

                fig = px.scatter(
                    df,
                    x=x_col,
                    y=y_col,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )

                # Add tooltips if provided
                if tooltip_fields:
                    fig.update_traces(
                        hovertemplate="<br>".join([f"{field}: %{{customdata[{i}]}}" for i, field in enumerate(tooltip_fields)]) + "<extra></extra>",
                        customdata=df[tooltip_fields].values if all(field in df.columns for field in tooltip_fields) else None
                    )

                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "area":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for area chart")

                fig = px.area(
                    df,
                    x=x_col,
                    y=y_col,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "histogram":
                x_col = config.get("x_axis")

                if not x_col or x_col not in df.columns:
                    raise ValueError("Required column not found for histogram")

                fig = px.histogram(
                    df,
                    x=x_col,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "box":
                y_col = config.get("y_axis")
                x_col = config.get("x_axis")  # optional for grouping

                if not y_col or y_col not in df.columns:
                    raise ValueError("Required column not found for box plot")

                fig = px.box(
                    df,
                    y=y_col,
                    x=x_col if x_col and x_col in df.columns else None,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "violin":
                y_col = config.get("y_axis")
                x_col = config.get("x_axis")  # optional for grouping

                if not y_col or y_col not in df.columns:
                    raise ValueError("Required column not found for violin plot")

                fig = px.violin(
                    df,
                    y=y_col,
                    x=x_col if x_col and x_col in df.columns else None,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "heatmap":
                # For heatmap, we'll use all numeric columns or specified columns
                numeric_cols = df.select_dtypes(include=[np.number]).columns
                if len(numeric_cols) < 2:
                    raise ValueError("Not enough numeric columns for heatmap")

                correlation_matrix = df[numeric_cols].corr()
                fig = px.imshow(
                    correlation_matrix, title=title, color_continuous_scale="RdBu_r"
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "donut":
                names_col = config.get("names_column", config.get("x_axis"))
                values_col = config.get("values_column", config.get("y_axis"))

                if (
                    not names_col
                    or not values_col
                    or names_col not in df.columns
                    or values_col not in df.columns
                ):
                    raise ValueError("Required columns not found for donut chart")

                fig = px.pie(
                    df,
                    names=names_col,
                    values=values_col,
                    title=title,
                    hole=0.4,  # This makes it a donut chart
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "funnel":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for funnel chart")

                fig = px.funnel(
                    df,
                    x=y_col,
                    y=x_col,
                    title=title,
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "gauge":
                value_col = config.get("value_column", config.get("y_axis"))
                if not value_col or value_col not in df.columns:
                    raise ValueError("Required column not found for gauge chart")

                value = df[value_col].iloc[0] if len(df) > 0 else 0
                max_val = config.get("max_value", df[value_col].max())

                fig = go.Figure(
                    go.Indicator(
                        mode="gauge+number",
                        value=value,
                        domain={"x": [0, 1], "y": [0, 1]},
                        title={"text": title},
                        gauge={
                            "axis": {"range": [None, max_val]},
                            "bar": {"color": config.get("color", "#1f77b4")},
                            "steps": [
                                {"range": [0, max_val / 2], "color": "lightgray"},
                                {"range": [max_val / 2, max_val], "color": "gray"},
                            ],
                            "threshold": {
                                "line": {"color": "red", "width": 4},
                                "thickness": 0.75,
                                "value": max_val * 0.9,
                            },
                        },
                    )
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "treemap":
                path_col = config.get("path_column", config.get("x_axis"))
                values_col = config.get("values_column", config.get("y_axis"))

                if (
                    not path_col
                    or not values_col
                    or path_col not in df.columns
                    or values_col not in df.columns
                ):
                    raise ValueError("Required columns not found for treemap")

                fig = px.treemap(df, path=[path_col], values=values_col, title=title)
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "sunburst":
                path_col = config.get("path_column", config.get("x_axis"))
                values_col = config.get("values_column", config.get("y_axis"))

                if (
                    not path_col
                    or not values_col
                    or path_col not in df.columns
                    or values_col not in df.columns
                ):
                    raise ValueError("Required columns not found for sunburst")

                fig = px.sunburst(df, path=[path_col], values=values_col, title=title)
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "waterfall":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for waterfall chart")

                fig = go.Figure(
                    go.Waterfall(
                        name=title,
                        orientation="v",
                        measure=["relative"] * len(df),
                        x=df[x_col],
                        textposition="outside",
                        text=[f"{v:,.0f}" for v in df[y_col]],
                        y=df[y_col],
                        connector={"line": {"color": "rgb(63, 63, 63)"}},
                    )
                )
                fig.update_layout(title=title, showlegend=False)
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "stacked_bar":
                x_col = config.get("x_axis")
                y_cols = config.get("y_axis")  # This is a list
                stack_field = config.get("stack_field", "Metric Type")
                title = config.get("title", "Stacked Bar Chart")

                # Validate columns
                if (
                    not x_col
                    or not y_cols
                    or x_col not in df.columns
                    or not all(col in df.columns for col in y_cols)
                ):
                    raise ValueError("Required columns not found for stacked bar chart")

                # Melt the DataFrame to long format
                df_melted = df.melt(
                    id_vars=[x_col],
                    value_vars=y_cols,
                    var_name=stack_field,
                    value_name="Value",
                )

                # Create the chart
                fig = px.bar(
                    df_melted,
                    x=x_col,
                    y="Value",
                    color=stack_field,
                    title=title,
                    barmode="stack",
                    color_discrete_map=config.get(
                        "colors", {}
                    ),  # Optional custom colors
                )

                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "grouped_bar":
                x_col = config.get("x_axis")
                color_col = config.get("color_column")
                y_col = config.get("y_axis")
                stack_field = config.get("stack_field", "Metric Type")

                # if (
                #     not x_col
                #     or not y_col
                #     or x_col not in df.columns
                #    or not all(col in df.columns for col in y_col)
                # ):
                #     raise ValueError("Required columns not found for grouped bar chart")

                df_melted = df.melt(
                    id_vars=[x_col],
                    value_vars=y_col,
                    var_name=stack_field,
                    value_name="Value",
                )

                fig = px.bar(
                    df_melted,
                    x=x_col,
                      y="Value",
                    color=stack_field,
                    title=title,
                    barmode="group",
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "horizontal_bar":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError(
                        "Required columns not found for horizontal bar chart"
                    )

                fig = px.bar(
                    df,
                    x=y_col,
                    y=x_col,
                    title=title,
                    orientation="h",
                    color_discrete_sequence=[config.get("color", "#1f77b4")],
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "multi_line":
                x_col = config.get("x_axis")
                y_cols = config.get("y_columns", [])  # List of columns

                if not x_col or not y_cols or x_col not in df.columns:
                    raise ValueError("Required columns not found for multi-line chart")

                fig = go.Figure()
                colors = [
                    "#1f77b4",
                    "#ff7f0e",
                    "#2ca02c",
                    "#d62728",
                    "#9467bd",
                    "#8c564b",
                ]

                for i, y_col in enumerate(y_cols):
                    if y_col in df.columns:
                        fig.add_trace(
                            go.Scatter(
                                x=df[x_col],
                                y=df[y_col],
                                mode="lines+markers",
                                name=y_col,
                                line=dict(color=colors[i % len(colors)]),
                            )
                        )

                fig.update_layout(title=title, xaxis_title=x_col, yaxis_title="Values")
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "combo":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")
                line_col = config.get("line_column")

                if not all([x_col, y_col, line_col]) or not all(
                    [col in df.columns for col in [x_col, y_col, line_col]]
                ):
                    raise ValueError("Required columns not found for combo chart")

                fig = make_subplots(specs=[[{"secondary_y": True}]])

                # Add bar chart
                fig.add_trace(
                    go.Bar(x=df[x_col], y=df[y_col], name=y_col),
                    secondary_y=False,
                )

                # Add line chart
                fig.add_trace(
                    go.Scatter(
                        x=df[x_col], y=df[line_col], mode="lines+markers", name=line_col
                    ),
                    secondary_y=True,
                )

                fig.update_xaxes(title_text=x_col)
                fig.update_yaxes(title_text=y_col, secondary_y=False)
                fig.update_yaxes(title_text=line_col, secondary_y=True)
                fig.update_layout(title_text=title)

                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "bullet":
                value_col = config.get("value_column", config.get("y_axis"))
                target_col = config.get("target_column")

                if not value_col or value_col not in df.columns:
                    raise ValueError("Required column not found for bullet chart")

                value = df[value_col].iloc[0] if len(df) > 0 else 0
                target = (
                    df[target_col].iloc[0]
                    if target_col and target_col in df.columns
                    else value * 1.2
                )

                fig = go.Figure()

                fig.add_trace(
                    go.Indicator(
                        mode="number+gauge+delta",
                        value=value,
                        domain={"x": [0, 1], "y": [0, 1]},
                        title={"text": title},
                        delta={"reference": target},
                        gauge={
                            "shape": "bullet",
                            "axis": {"range": [None, target * 1.5]},
                            "threshold": {
                                "line": {"color": "red", "width": 2},
                                "thickness": 0.75,
                                "value": target,
                            },
                            "steps": [
                                {"range": [0, target * 0.5], "color": "lightgray"},
                                {"range": [target * 0.5, target], "color": "gray"},
                            ],
                            "bar": {"color": config.get("color", "#1f77b4")},
                        },
                    )
                )

                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "density_heatmap":
                x_col = config.get("x_axis")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError("Required columns not found for density heatmap")

                fig = px.density_heatmap(
                    df,
                    x=x_col,
                    y=y_col,
                    title=title,
                    marginal_x="histogram",
                    marginal_y="histogram",
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "stacked_area":
                x_col = config.get("x_axis")
                color_col = config.get("color_column")
                y_col = config.get("y_axis")

                if (
                    not x_col
                    or not y_col
                    or x_col not in df.columns
                    or y_col not in df.columns
                ):
                    raise ValueError(
                        "Required columns not found for stacked area chart"
                    )

                fig = px.area(
                    df,
                    x=x_col,
                    y=y_col,
                    color=color_col if color_col and color_col in df.columns else None,
                    title=title,
                )
                st.plotly_chart(fig, use_container_width=True)

            elif chart_type == "radar":
                # For radar charts, we need multiple numeric columns
                categories_col = config.get("categories_column", config.get("x_axis"))
                values_cols = config.get("values_columns", [])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)