DEV Community

Chris Whong for Mapbox

Gradio X Mapbox

A developer on r/mapbox asked if anyone had successfully integrated a Mapbox GL JS map with Gradio, a Python library for building web interfaces for machine learning models.

After a quick peek at Gradio's docs, I found the sections on adding custom CSS and JS. With a bit of help from AI, I learned how to get data out of the Python world and into the JavaScript world in a Gradio app. There are several ways to do it; the one I used involves stashing the data in a hidden output component.
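The payload that crosses from Python to JavaScript is just a GeoJSON string. Below is a minimal, stdlib-only sketch of building one from plain (city, lon, lat) tuples instead of a pandas DataFrame. The function name rows_to_point_geojson is hypothetical. Note that GeoJSON (RFC 7946) orders positions as [longitude, latitude], the reverse of the common "lat, lon" habit.

```python
import json


def rows_to_point_geojson(rows):
    """Build a GeoJSON FeatureCollection string from (city, lon, lat) tuples.

    Hypothetical helper; GeoJSON positions are [longitude, latitude].
    """
    features = [
        {
            "type": "Feature",
            "geometry": {"type": "Point", "coordinates": [lon, lat]},
            "properties": {"city": city},
        }
        for city, lon, lat in rows
    ]
    # A JSON string like this is what gets stashed in the hidden output
    return json.dumps({"type": "FeatureCollection", "features": features})


cities = [("New York", -74.0060, 40.7128), ("Chicago", -87.6298, 41.8781)]
payload = rows_to_point_geojson(cities)
print(payload)
```

On the Gradio side, a string like payload becomes the value of the hidden output component that the map code reads.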

I built a demo to show a simple implementation of Gradio X Mapbox:

  • Start with a pandas DataFrame of U.S. cities with their longitude/latitude coordinates.
  • Convert the city points to GeoJSON in Python and drop the GeoJSON into a hidden output.
  • On map load, pick up the GeoJSON and add circles representing the cities to the map using a geojson source and a circle layer.
  • Add a button that triggers some processing in Python. In this case, it's the calculation of a convex hull around the point data.
  • Convert the convex hull to GeoJSON and stash it in another hidden output.
  • Render the convex hull polygon on the map using a geojson source, a fill layer, and a line layer for its outline.
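The convex hull is the smallest convex polygon that contains every point. The demo gets it from shapely's convex_hull; to show what that step computes without the shapely dependency, here is a sketch that swaps in Andrew's monotone chain algorithm (O(n log n)). The function names are hypothetical. Like shapely, it treats longitude/latitude as flat x/y coordinates, and it assumes at least three non-collinear input points.

```python
import json


def cross(o, a, b):
    """Z-component of (a - o) x (b - o); positive means a counterclockwise turn."""
    return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])


def convex_hull_polygon(points):
    """Andrew's monotone chain hull of (lon, lat) tuples as a GeoJSON Polygon.

    Planar math on lon/lat; assumes >= 3 non-collinear points.
    """
    pts = sorted(set(points))
    lower, upper = [], []
    # Lower chain: left to right, dropping points that make a clockwise turn
    for p in pts:
        while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)
    # Upper chain: right to left, same rule
    for p in reversed(pts):
        while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)
    # Each chain's last point repeats the other chain's first point
    hull = lower[:-1] + upper[:-1]
    # GeoJSON rings must be closed: repeat the first position at the end
    ring = [list(p) for p in hull] + [list(hull[0])]
    return {"type": "Polygon", "coordinates": [ring]}


square_plus_center = [(0, 0), (2, 0), (2, 2), (0, 2), (1, 1)]
print(json.dumps(convex_hull_polygon(square_plus_center)))
```

The interior point (1, 1) is dropped, and the ring comes out counterclockwise, which is the winding RFC 7946 recommends for exterior rings.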

There are a few considerations for loading Mapbox GL JS:

  • Load Mapbox GL JS and its CSS from the Mapbox CDN using the head argument of demo.launch().
  • Create a map container div using gr.HTML().
  • Before instantiating the map, poll for the existence of the map container, since it may not exist yet when the JavaScript code first runs. In the code below, initMap() checks for the container and, if it is missing, tries again after a 100 ms setTimeout.
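Because the CSS and JS URLs share a version segment, one way to keep them in sync is to build both tags from a single version string. This is a sketch; mapbox_head and map_container_html are hypothetical helper names, and their return values would be passed to demo.launch(head=...) and gr.HTML(...) respectively.

```python
def mapbox_head(version="3.18.1"):
    """Return the <link> and <script> tags that load Mapbox GL JS from the CDN."""
    base = f"https://api.mapbox.com/mapbox-gl-js/v{version}"
    return (
        f'<link href="{base}/mapbox-gl.css" rel="stylesheet">\n'
        f'<script src="{base}/mapbox-gl.js"></script>\n'
    )


def map_container_html(container_id="map-container", height_px=450):
    """Return the div the map renders into; the JS looks it up by id.

    An explicit height matters: a zero-height container shows no map.
    """
    return (
        f"<div id='{container_id}' "
        f"style='width: 100%; height: {height_px}px;'></div>"
    )


print(mapbox_head())
print(map_container_html())
```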

The full code snippet is below if you want to give it a try. You'll need a Mapbox access token to get it working (replace YOUR_MAPBOX_ACCESS_TOKEN in the code), so sign up for a Mapbox account if you don't already have one.

import gradio as gr
import pandas as pd
import json
from shapely.geometry import MultiPoint, mapping, Point


# cities DataFrame
df = pd.DataFrame({
    "city": [
        "New York",
        "Los Angeles",
        "Chicago",
        "Houston",
        "Phoenix",
        "Philadelphia",
        "San Antonio",
        "San Diego",
        "Dallas",
        "San Jose",
    ],
    "lat": [
        40.7128,
        34.0522,
        41.8781,
        29.7604,
        33.4484,
        39.9526,
        29.4241,
        32.7157,
        32.7767,
        37.3382,
    ],
    "lon": [
        -74.0060,
        -118.2437,
        -87.6298,
        -95.3698,
        -112.0740,
        -75.1652,
        -98.4936,
        -117.1611,
        -96.7970,
        -121.8863,
    ],
})


# compute convex hull and return GeoJSON
def compute_convex_hull(data):
    points = [Point(row["lon"], row["lat"]) for _, row in data.iterrows()]
    multipoint = MultiPoint(points)
    hull = multipoint.convex_hull
    hull_geojson = mapping(hull)
    features = [
        {
            "type": "Feature",
            "geometry": mapping(pt),
            "properties": {"city": row["city"]},
        }
        for pt, (_, row) in zip(points, data.iterrows())
    ]
    geojson = {
        "type": "FeatureCollection",
        "features": [
            {
                "type": "Feature",
                "geometry": hull_geojson,
                "properties": {"type": "convex_hull"},
            }
        ]
        + features,
    }
    return json.dumps(geojson)


# convert points to GeoJSON for initial display
def get_points_geojson(data):
    """Convert dataframe to GeoJSON for initial point display"""
    features = [
        {
            "type": "Feature",
            "geometry": {"type": "Point", "coordinates": [row["lon"], row["lat"]]},
            "properties": {"city": row["city"]},
        }
        for _, row in data.iterrows()
    ]
    return json.dumps({"type": "FeatureCollection", "features": features})


# Gradio interface, two columns: left data table and button, right map display
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                # US Cities Convex Hull Demo
                This demo computes the convex hull of a set of US cities and outputs the result as GeoJSON for display on a [Mapbox GL JS](https://docs.mapbox.com/mapbox-gl-js/) map.

                GeoJSON strings are assembled in Python and stored in hidden Gradio outputs for easy access in the JavaScript map code.
                """
            )
            gr.Dataframe(df, label="City Locations")
            btn = gr.Button("Compute Convex Hull GeoJSON")


        # this html includes the map container, Mapbox GL JS will add a map in this div
        with gr.Column():
            gr.HTML(
                """
                <div>
                    <h3>Mapbox GL JS Visualization</h3>
                    <div id='map-container' style='width: 100%; height: 450px; background-color: lightgray;'></div>
                    <div>© Mapbox &nbsp; © OpenStreetMap</div>
                </div>
                """
            )


    # hidden output components that hold GeoJSON for the JavaScript map code
    geojson_output = gr.JSON(visible=False)
    initial_points = gr.JSON(value=get_points_geojson(df), visible=False)


    def handle_click():
        return compute_convex_hull(df)


    # on click, compute convex hull and update map
    btn.click(
        handle_click,
        outputs=geojson_output,
    ).then(
        None,
        inputs=geojson_output,
        outputs=None,
        js="""
        (data) => {
            console.log("Received GeoJSON data:", data);
            if (window.map) {
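                // Remove existing hull layers first; Mapbox GL JS will not
                // remove a source while a layer still references it
                if (window.map.getLayer('convex-hull-outline')) {
                    window.map.removeLayer('convex-hull-outline');
                }
                if (window.map.getLayer('convex-hull')) {
                    window.map.removeLayer('convex-hull');
                }
                if (window.map.getSource('hull')) {
                    window.map.removeSource('hull');
                }
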
                // Add hull
                window.map.addSource('hull', {
                    type: 'geojson',
                    data: data,
                });


                window.map.addLayer({
                    id: 'convex-hull',
                    type: 'fill',
                    source: 'hull',
                    paint: {
                        'fill-color': '#088',
                        'fill-opacity': 0.3,
                        'fill-emissive-strength': 1,
                    },
                });


                window.map.addLayer({
                    id: 'convex-hull-outline',
                    type: 'line',
                    source: 'hull',
                    paint: {
                        'line-color': '#FFF',
                        'line-width': 2,
                        'line-emissive-strength': 1,
                    },
                });
            }
        }
        """,
    )


    # on initial load, initialize a map, then add points as a GeoJSON source and circle layer
    # initMap() contains a retry mechanism to wait for the map container to be available,
    # as Gradio may take some time to render the HTML and the map container div may not be present immediately
    demo.load(
        None,
        inputs=initial_points,
        outputs=None,
        js="""
        (pointsData) => {
            function initMap() {
                const container = document.getElementById('map-container');
                if (!container) {
                    setTimeout(initMap, 100);
                    return;
                }


                mapboxgl.accessToken = 'YOUR_MAPBOX_ACCESS_TOKEN';


                window.map = new mapboxgl.Map({
                    container: 'map-container',
                    style: 'mapbox://styles/mapbox/standard',
                    config: {
                        basemap: {
                            theme: 'monochrome',
                            lightPreset: 'night',
                        },
                    },
                    center: [-98.92906, 40.25617],
                    zoom: 2.5,
                    projection: 'mercator',
                    attributionControl: false,
                });


                // Add points after map loads
                window.map.on('load', function () {
                    window.map.addSource('cities', {
                        type: 'geojson',
                        data: pointsData,
                    });


                    window.map.addLayer({
                        id: 'city-points',
                        type: 'circle',
                        source: 'cities',
                        paint: {
                            'circle-radius': 6,
                            'circle-color': '#B42222',
                            'circle-emissive-strength': 1,
                        },
                    });
                });
            }


            initMap();
        }
        """,
    )


# load mapbox gl js and its css in the head
head = """
<link href="https://api.mapbox.com/mapbox-gl-js/v3.18.1/mapbox-gl.css" rel="stylesheet">
<script src="https://api.mapbox.com/mapbox-gl-js/v3.18.1/mapbox-gl.js"></script>
"""


# custom CSS to position the Mapbox logo, as Gradio's CSS interferes with the default positioning
# note that attributionControl is set to false in the map options because Gradio CSS also interferes with the attribution control
# map attribution is required and was manually added in the HTML above
custom_css = """
.mapboxgl-ctrl-bottom-left .mapboxgl-ctrl,
.mapboxgl-ctrl-left .mapboxgl-ctrl {
    margin: 0 0 10px 10px !important;
}
"""


demo.launch(head=head, css=custom_css)
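The snippet hardcodes 'YOUR_MAPBOX_ACCESS_TOKEN' inside the JavaScript string. If you'd rather keep the token out of source control, one option is to read it from an environment variable in Python and substitute it into the JS before passing the string to demo.load(). The variable name MAPBOX_ACCESS_TOKEN, the __MAPBOX_TOKEN__ placeholder, and inject_token are assumptions for illustration, not Gradio or Mapbox conventions. The token still ends up in the browser, so use a public (pk.) token.

```python
import os

# Hypothetical JS fragment with a placeholder where the token goes
MAP_JS_TEMPLATE = "mapboxgl.accessToken = '__MAPBOX_TOKEN__';"


def inject_token(js_template, env_var="MAPBOX_ACCESS_TOKEN"):
    """Replace the placeholder with a token read from the environment."""
    token = os.environ.get(env_var)
    if not token:
        # Fail at startup rather than rendering a blank map later
        raise RuntimeError(f"Set the {env_var} environment variable to a Mapbox access token")
    return js_template.replace("__MAPBOX_TOKEN__", token)
```

In the app, the js= string passed to demo.load() would contain the placeholder instead of the literal token and be wrapped with inject_token(...).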