DEV Community

Armaan Khan
Armaan Khan

Posted on

chart

chart
from typing import Dict

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from plotly.subplots import make_subplots

class ChartBuilder:
    """Builds charts and KPIs with error handling and fallback to tables.

    ``create_chart`` dispatches on ``config["chart_type"]`` to a
    ``_render_<chart_type>`` method (see ``_CHART_TYPES``). Each renderer
    validates the columns it needs and returns a Plotly figure (or ``None``
    when it renders directly, as the KPI card does). All rendering runs inside
    one error boundary: any failure shows a warning and displays the data as a
    table instead. ``"table"`` and unrecognised chart types render the table
    directly.
    """

    # Colour used when the config has no "color" key.
    DEFAULT_COLOR = "#1f77b4"

    # Palette cycled through by the multi-line chart, one colour per series.
    _LINE_PALETTE = (
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
    )

    # Chart types whose title defaults to something other than "Chart".
    _DEFAULT_TITLES = {"pie": "Pie Chart", "stacked_bar": "Stacked Bar Chart"}

    # Supported chart types; each one is rendered by ``_render_<type>``.
    _CHART_TYPES = (
        "kpi",
        "bar",
        "line",
        "pie",
        "scatter",
        "area",
        "histogram",
        "box",
        "violin",
        "heatmap",
        "donut",
        "funnel",
        "gauge",
        "treemap",
        "sunburst",
        "waterfall",
        "stacked_bar",
        "grouped_bar",
        "horizontal_bar",
        "multi_line",
        "combo",
        "bullet",
        "density_heatmap",
        "stacked_area",
        "radar",
        "candlestick",
        "sankey",
    )

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    @staticmethod
    def create_kpi_card(
        value, title: str, format_type: str = "number", color: str = "#1f77b4"
    ):
        """Creates a KPI metric card with ``st.metric``.

        Args:
            value: Scalar to display; ``None`` or NaN is shown as "N/A".
            title: Metric label.
            format_type: "currency" -> "$1,234.56", "percentage" -> "12.3%",
                anything else -> thousands-separated integer "1,234".
            color: Accepted for API compatibility. ``st.metric`` takes no
                colour argument, so it is currently unused.

        If formatting fails (e.g. a non-numeric value), an error message is
        shown followed by a metric whose value is "Error".
        """
        try:
            formatted_value = ChartBuilder._format_kpi_value(value, format_type)
            st.metric(label=title, value=formatted_value)
        except Exception as e:
            st.error(f"Error creating KPI card: {str(e)}")
            st.metric(label=title, value="Error")

    @staticmethod
    def create_chart(df: pd.DataFrame, config: Dict) -> None:
        """
        Creates charts based on config with fallback to table on errors.

        Args:
            df: Data to plot. ``None`` or an empty frame only shows a warning.
            config: Chart description. ``chart_type`` (case-insensitive,
                default "table") selects the renderer and ``title`` sets the
                title; the remaining keys are read by the individual
                ``_render_*`` methods.
        """
        if df is None or df.empty:
            st.warning("No data available for chart")
            return

        config = config or {}
        # `or` also covers an explicit None, which would crash on .lower().
        chart_type = str(config.get("chart_type") or "table").lower()
        title = config.get(
            "title", ChartBuilder._DEFAULT_TITLES.get(chart_type, "Chart")
        )

        try:
            if chart_type not in ChartBuilder._CHART_TYPES:
                # Explicit "table" request or an unknown chart type.
                ChartBuilder._create_table(df, title)
                return
            renderer = getattr(ChartBuilder, f"_render_{chart_type}")
            fig = renderer(df, config, title)
            if fig is not None:
                st.plotly_chart(fig, use_container_width=True)
        except Exception as e:
            st.warning(
                f"Chart rendering failed, showing table instead. Error: {str(e)}"
            )
            ChartBuilder._create_table(df, title)

    # ------------------------------------------------------------------ #
    # Pure helpers (no Streamlit / Plotly calls)
    # ------------------------------------------------------------------ #

    @staticmethod
    def _format_kpi_value(value, format_type: str = "number") -> str:
        """Format a KPI value as text; see ``create_kpi_card`` for the formats."""
        if value is None or pd.isna(value):
            return "N/A"
        number = float(value)
        if format_type == "currency":
            return f"${number:,.2f}"
        if format_type == "percentage":
            return f"{number:.1f}%"
        return f"{number:,.0f}"

    @staticmethod
    def _as_column_list(spec) -> list:
        """Normalise a column spec: None -> [], "a" -> ["a"], iterable -> list.

        Without this, a single column name given as a string would be
        iterated character by character.
        """
        if spec is None:
            return []
        if isinstance(spec, str):
            return [spec]
        return list(spec)

    @staticmethod
    def _has_columns(df: pd.DataFrame, *cols) -> bool:
        """Return True when every name in ``cols`` is truthy and a column of ``df``."""
        return all(col and col in df.columns for col in cols)

    @staticmethod
    def _require_columns(df: pd.DataFrame, message: str, *cols) -> None:
        """Raise ``ValueError(message)`` unless ``_has_columns(df, *cols)``."""
        if not ChartBuilder._has_columns(df, *cols):
            raise ValueError(message)

    @staticmethod
    def _color(config: Dict):
        """Single-series colour from the config, defaulting to ``DEFAULT_COLOR``."""
        return config.get("color", ChartBuilder.DEFAULT_COLOR)

    @staticmethod
    def _xy_columns(df: pd.DataFrame, config: Dict, label: str):
        """Return validated ``(x_axis, y_axis)`` column names for a chart ``label``."""
        x_col = config.get("x_axis")
        y_col = config.get("y_axis")
        ChartBuilder._require_columns(
            df, f"Required columns not found for {label}", x_col, y_col
        )
        return x_col, y_col

    @staticmethod
    def _melt_series(df: pd.DataFrame, x_col, y_cols: list, stack_field: str):
        """Wide -> long: one row per (x value, series) with the value in "Value"."""
        return df.melt(
            id_vars=[x_col],
            value_vars=y_cols,
            var_name=stack_field,
            value_name="Value",
        )

    @staticmethod
    def _pie_summary(df: pd.DataFrame, y_cols: list, stack_field: str):
        """Total each column in ``y_cols``.

        Returns a frame with one row per column: its name in ``stack_field``
        and its sum in ``TOTAL_COUNT``. No identifier column is needed, since
        only the per-column totals are kept.
        """
        return (
            df[y_cols]
            .melt(var_name=stack_field, value_name="Value")
            .groupby(stack_field, as_index=False)["Value"]
            .sum()
            .rename(columns={"Value": "TOTAL_COUNT"})
        )

    @staticmethod
    def _sankey_links(sources: list, targets: list) -> tuple:
        """Index Sankey nodes in first-seen order.

        Returns ``(labels, source_indices, target_indices)``. First-seen
        order is used instead of a ``set``: set iteration order for strings
        depends on per-process hash randomisation, which would reshuffle node
        positions and colours between server restarts.
        """
        labels = list(dict.fromkeys(list(sources) + list(targets)))
        index = {label: i for i, label in enumerate(labels)}
        return labels, [index[s] for s in sources], [index[t] for t in targets]

    # ------------------------------------------------------------------ #
    # Shared figure builders
    # ------------------------------------------------------------------ #

    @staticmethod
    def _single_series(plot_fn, df, config, title, label):
        """Plotly Express x/y chart in a single colour (bar, line, scatter, area)."""
        x_col, y_col = ChartBuilder._xy_columns(df, config, label)
        return plot_fn(
            df,
            x=x_col,
            y=y_col,
            title=title,
            color_discrete_sequence=[ChartBuilder._color(config)],
        )

    @staticmethod
    def _distribution(plot_fn, df, config, title, label):
        """Box/violin of ``y_axis``, grouped by ``x_axis`` when that column exists."""
        y_col = config.get("y_axis")
        x_col = config.get("x_axis")  # optional grouping column
        ChartBuilder._require_columns(
            df, f"Required column not found for {label}", y_col
        )
        return plot_fn(
            df,
            y=y_col,
            x=x_col if ChartBuilder._has_columns(df, x_col) else None,
            title=title,
            color_discrete_sequence=[ChartBuilder._color(config)],
        )

    @staticmethod
    def _hierarchy(plot_fn, df, config, title, label):
        """One-level treemap/sunburst from ``path_column`` (or ``x_axis``)."""
        path_col = config.get("path_column", config.get("x_axis"))
        values_col = config.get("values_column", config.get("y_axis"))
        ChartBuilder._require_columns(
            df, f"Required columns not found for {label}", path_col, values_col
        )
        return plot_fn(df, path=[path_col], values=values_col, title=title)

    # ------------------------------------------------------------------ #
    # Renderers: (df, config, title) -> figure, or None if already rendered
    # ------------------------------------------------------------------ #

    @staticmethod
    def _render_kpi(df, config, title):
        """KPI card from the first row of ``value_column`` (rendered directly)."""
        value_col = config.get("value_column")
        if value_col and value_col in df.columns:
            ChartBuilder.create_kpi_card(
                df[value_col].iloc[0],  # df is known to be non-empty here
                title,
                config.get("format", "number"),
                ChartBuilder._color(config),
            )
        else:
            st.error("KPI value column not found")
        return None

    @staticmethod
    def _render_bar(df, config, title):
        fig = ChartBuilder._single_series(px.bar, df, config, title, "bar chart")
        fig.update_layout(showlegend=False)
        return fig

    @staticmethod
    def _render_line(df, config, title):
        return ChartBuilder._single_series(px.line, df, config, title, "line chart")

    @staticmethod
    def _render_scatter(df, config, title):
        return ChartBuilder._single_series(
            px.scatter, df, config, title, "scatter chart"
        )

    @staticmethod
    def _render_area(df, config, title):
        return ChartBuilder._single_series(px.area, df, config, title, "area chart")

    @staticmethod
    def _render_pie(df, config, title):
        """Pie of column totals: each ``y_axis`` column becomes one slice."""
        y_cols = ChartBuilder._as_column_list(config.get("y_axis"))
        stack_field = config.get("stack_field", "METRIC_TYPE")
        if not y_cols or not ChartBuilder._has_columns(df, *y_cols):
            raise ValueError("Required columns not found for pie chart")

        summary = ChartBuilder._pie_summary(df, y_cols, stack_field)
        return px.pie(summary, names=stack_field, values="TOTAL_COUNT", title=title)

    @staticmethod
    def _render_histogram(df, config, title):
        x_col = config.get("x_axis")
        ChartBuilder._require_columns(
            df, "Required column not found for histogram", x_col
        )
        return px.histogram(
            df,
            x=x_col,
            title=title,
            color_discrete_sequence=[ChartBuilder._color(config)],
        )

    @staticmethod
    def _render_box(df, config, title):
        return ChartBuilder._distribution(px.box, df, config, title, "box plot")

    @staticmethod
    def _render_violin(df, config, title):
        return ChartBuilder._distribution(px.violin, df, config, title, "violin plot")

    @staticmethod
    def _render_heatmap(df, config, title):
        """Correlation matrix of all numeric columns."""
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) < 2:
            raise ValueError("Not enough numeric columns for heatmap")
        return px.imshow(
            df[numeric_cols].corr(), title=title, color_continuous_scale="RdBu_r"
        )

    @staticmethod
    def _render_donut(df, config, title):
        names_col = config.get("names_column", config.get("x_axis"))
        values_col = config.get("values_column", config.get("y_axis"))
        ChartBuilder._require_columns(
            df, "Required columns not found for donut chart", names_col, values_col
        )
        # hole=0.4 turns the pie into a donut.
        return px.pie(df, names=names_col, values=values_col, title=title, hole=0.4)

    @staticmethod
    def _render_funnel(df, config, title):
        x_col, y_col = ChartBuilder._xy_columns(df, config, "funnel chart")
        # Stages (x_axis) are laid out vertically, their values horizontally.
        return px.funnel(
            df,
            x=y_col,
            y=x_col,
            title=title,
            color_discrete_sequence=[ChartBuilder._color(config)],
        )

    @staticmethod
    def _render_gauge(df, config, title):
        """Gauge of the first value; range defaults to the column maximum."""
        value_col = config.get("value_column", config.get("y_axis"))
        ChartBuilder._require_columns(
            df, "Required column not found for gauge chart", value_col
        )
        value = df[value_col].iloc[0]
        max_val = config.get("max_value", df[value_col].max())

        return go.Figure(
            go.Indicator(
                mode="gauge+number",
                value=value,
                domain={"x": [0, 1], "y": [0, 1]},
                title={"text": title},
                gauge={
                    "axis": {"range": [None, max_val]},
                    "bar": {"color": ChartBuilder._color(config)},
                    "steps": [
                        {"range": [0, max_val / 2], "color": "lightgray"},
                        {"range": [max_val / 2, max_val], "color": "gray"},
                    ],
                    "threshold": {
                        "line": {"color": "red", "width": 4},
                        "thickness": 0.75,
                        "value": max_val * 0.9,
                    },
                },
            )
        )

    @staticmethod
    def _render_treemap(df, config, title):
        return ChartBuilder._hierarchy(px.treemap, df, config, title, "treemap")

    @staticmethod
    def _render_sunburst(df, config, title):
        return ChartBuilder._hierarchy(px.sunburst, df, config, title, "sunburst")

    @staticmethod
    def _render_waterfall(df, config, title):
        """Waterfall where every row is a relative step."""
        x_col, y_col = ChartBuilder._xy_columns(df, config, "waterfall chart")
        fig = go.Figure(
            go.Waterfall(
                name=title,
                orientation="v",
                measure=["relative"] * len(df),
                x=df[x_col],
                textposition="outside",
                text=[f"{v:,.0f}" for v in df[y_col]],
                y=df[y_col],
                connector={"line": {"color": "rgb(63, 63, 63)"}},
            )
        )
        fig.update_layout(title=title, showlegend=False)
        return fig

    @staticmethod
    def _render_stacked_bar(df, config, title):
        """Bars per ``x_axis`` value, stacked over the ``y_axis`` columns."""
        x_col = config.get("x_axis")
        y_cols = ChartBuilder._as_column_list(config.get("y_axis"))
        stack_field = config.get("stack_field", "Metric Type")
        if not y_cols or not ChartBuilder._has_columns(df, x_col, *y_cols):
            raise ValueError("Required columns not found for stacked bar chart")

        return px.bar(
            ChartBuilder._melt_series(df, x_col, y_cols, stack_field),
            x=x_col,
            y="Value",
            color=stack_field,
            title=title,
            barmode="stack",
            color_discrete_map=config.get("colors", {}),  # optional custom colours
        )

    @staticmethod
    def _render_grouped_bar(df, config, title):
        """Side-by-side bars per ``x_axis`` value, one per ``y_axis`` column.

        Without ``y_axis``, every column other than ``x_axis`` is used.
        """
        x_col = config.get("x_axis")
        stack_field = config.get("stack_field", "Metric Type")
        y_cols = ChartBuilder._as_column_list(config.get("y_axis"))
        if not y_cols:
            y_cols = [col for col in df.columns if col != x_col]
        if not y_cols or not ChartBuilder._has_columns(df, x_col, *y_cols):
            raise ValueError("Required columns not found for grouped bar chart")

        return px.bar(
            ChartBuilder._melt_series(df, x_col, y_cols, stack_field),
            x=x_col,
            y="Value",
            color=stack_field,
            title=title,
            barmode="group",
        )

    @staticmethod
    def _render_horizontal_bar(df, config, title):
        x_col, y_col = ChartBuilder._xy_columns(df, config, "horizontal bar chart")
        # Categories (x_axis) go on the vertical axis for horizontal bars.
        return px.bar(
            df,
            x=y_col,
            y=x_col,
            title=title,
            orientation="h",
            color_discrete_sequence=[ChartBuilder._color(config)],
        )

    @staticmethod
    def _render_multi_line(df, config, title):
        """One line per existing column in ``y_columns``; missing ones are skipped."""
        x_col = config.get("x_axis")
        y_cols = ChartBuilder._as_column_list(config.get("y_columns", []))
        series = [col for col in y_cols if col in df.columns]
        if not ChartBuilder._has_columns(df, x_col) or not series:
            raise ValueError("Required columns not found for multi-line chart")

        palette = ChartBuilder._LINE_PALETTE
        fig = go.Figure()
        for i, col in enumerate(series):
            fig.add_trace(
                go.Scatter(
                    x=df[x_col],
                    y=df[col],
                    mode="lines+markers",
                    name=col,
                    line=dict(color=palette[i % len(palette)]),
                )
            )
        fig.update_layout(title=title, xaxis_title=x_col, yaxis_title="Values")
        return fig

    @staticmethod
    def _render_combo(df, config, title):
        """Bars for ``y_axis`` plus a line for ``line_column`` on a secondary y axis."""
        x_col = config.get("x_axis")
        y_col = config.get("y_axis")
        line_col = config.get("line_column")
        ChartBuilder._require_columns(
            df, "Required columns not found for combo chart", x_col, y_col, line_col
        )

        fig = make_subplots(specs=[[{"secondary_y": True}]])
        fig.add_trace(
            go.Bar(x=df[x_col], y=df[y_col], name=y_col),
            secondary_y=False,
        )
        fig.add_trace(
            go.Scatter(x=df[x_col], y=df[line_col], mode="lines+markers", name=line_col),
            secondary_y=True,
        )
        fig.update_xaxes(title_text=x_col)
        fig.update_yaxes(title_text=y_col, secondary_y=False)
        fig.update_yaxes(title_text=line_col, secondary_y=True)
        fig.update_layout(title_text=title)
        return fig

    @staticmethod
    def _render_bullet(df, config, title):
        """Bullet gauge of the first value against ``target_column``."""
        value_col = config.get("value_column", config.get("y_axis"))
        target_col = config.get("target_column")
        ChartBuilder._require_columns(
            df, "Required column not found for bullet chart", value_col
        )

        value = df[value_col].iloc[0]
        # Without a target column, the target defaults to 20% above the value.
        if ChartBuilder._has_columns(df, target_col):
            target = df[target_col].iloc[0]
        else:
            target = value * 1.2

        fig = go.Figure()
        fig.add_trace(
            go.Indicator(
                mode="number+gauge+delta",
                value=value,
                domain={"x": [0, 1], "y": [0, 1]},
                title={"text": title},
                delta={"reference": target},
                gauge={
                    "shape": "bullet",
                    "axis": {"range": [None, target * 1.5]},
                    "threshold": {
                        "line": {"color": "red", "width": 2},
                        "thickness": 0.75,
                        "value": target,
                    },
                    "steps": [
                        {"range": [0, target * 0.5], "color": "lightgray"},
                        {"range": [target * 0.5, target], "color": "gray"},
                    ],
                    "bar": {"color": ChartBuilder._color(config)},
                },
            )
        )
        return fig

    @staticmethod
    def _render_density_heatmap(df, config, title):
        x_col, y_col = ChartBuilder._xy_columns(df, config, "density heatmap")
        return px.density_heatmap(
            df,
            x=x_col,
            y=y_col,
            title=title,
            marginal_x="histogram",
            marginal_y="histogram",
        )

    @staticmethod
    def _render_stacked_area(df, config, title):
        """Area chart, split by ``color_column`` when that column exists."""
        x_col, y_col = ChartBuilder._xy_columns(df, config, "stacked area chart")
        color_col = config.get("color_column")
        return px.area(
            df,
            x=x_col,
            y=y_col,
            color=color_col if ChartBuilder._has_columns(df, color_col) else None,
            title=title,
        )

    @staticmethod
    def _render_radar(df, config, title):
        """Radar chart.

        If ``categories_column`` (or ``x_axis``) exists, its values are the
        spokes and the first of ``values_columns`` (default: the numeric
        columns) gives the radii. Otherwise the numeric column names are the
        spokes and the first row gives the radii (at least 3 are needed).
        """
        categories_col = config.get("categories_column", config.get("x_axis"))

        if not ChartBuilder._has_columns(df, categories_col):
            numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
            if len(numeric_cols) < 3:
                raise ValueError("Not enough data for radar chart")
            categories = numeric_cols
            values = df[numeric_cols].iloc[0].values
        else:
            values_cols = ChartBuilder._as_column_list(
                config.get("values_columns", [])
            )
            if not values_cols:
                values_cols = [
                    col
                    for col in df.columns
                    if col != categories_col
                    and pd.api.types.is_numeric_dtype(df[col])
                ]
            if not values_cols:
                # Plotting categories with no radii would give an empty chart.
                raise ValueError("Not enough data for radar chart")
            categories = df[categories_col].tolist()
            values = df[values_cols[0]].tolist()

        fig = go.Figure()
        fig.add_trace(
            go.Scatterpolar(r=values, theta=categories, fill="toself", name=title)
        )
        fig.update_layout(
            polar=dict(radialaxis=dict(visible=True)),
            title=title,
            showlegend=True,
        )
        return fig

    @staticmethod
    def _render_candlestick(df, config, title):
        """OHLC candlestick; column names default to open/high/low/close."""
        date_col = config.get("date_column", config.get("x_axis"))
        open_col = config.get("open_column", "open")
        high_col = config.get("high_column", "high")
        low_col = config.get("low_column", "low")
        close_col = config.get("close_column", "close")
        ChartBuilder._require_columns(
            df,
            "Required OHLC columns not found for candlestick chart",
            date_col,
            open_col,
            high_col,
            low_col,
            close_col,
        )

        fig = go.Figure(
            data=go.Candlestick(
                x=df[date_col],
                open=df[open_col],
                high=df[high_col],
                low=df[low_col],
                close=df[close_col],
            )
        )
        fig.update_layout(title=title, xaxis_title=date_col)
        return fig

    @staticmethod
    def _render_sankey(df, config, title):
        """Sankey diagram: one link per row, source -> target weighted by value."""
        source_col = config.get("source_column")
        target_col = config.get("target_column")
        value_col = config.get("value_column", config.get("y_axis"))
        ChartBuilder._require_columns(
            df,
            "Required columns not found for sankey diagram",
            source_col,
            target_col,
            value_col,
        )

        labels, sources, targets = ChartBuilder._sankey_links(
            df[source_col].tolist(), df[target_col].tolist()
        )
        fig = go.Figure(
            data=[
                go.Sankey(
                    node=dict(
                        pad=15,
                        thickness=20,
                        line=dict(color="black", width=0.5),
                        label=labels,
                    ),
                    link=dict(source=sources, target=targets, value=df[value_col]),
                )
            ]
        )
        fig.update_layout(title_text=title, font_size=10)
        return fig

    # ------------------------------------------------------------------ #
    # Table fallback
    # ------------------------------------------------------------------ #

    @staticmethod
    def _create_table(df: pd.DataFrame, title: str):
        """Creates a formatted table display.

        Shows ``title`` as a subheader, then ``df`` full-width without the
        index, with float columns rounded to 2 decimals for readability.
        The caller's frame is not modified.
        """
        st.subheader(title)

        formatted_df = df.copy()
        for col in formatted_df.columns:
            if pd.api.types.is_float_dtype(formatted_df[col]):
                formatted_df[col] = formatted_df[col].round(2)

        st.dataframe(formatted_df, use_container_width=True, hide_index=True)
Enter fullscreen mode Exit fullscreen mode

Top comments (0)