Deploying PyTorch Models with Flask & Amazon EC2

So you have finally trained and built your machine learning model. It took days on end of research, training, and fine-tuning to achieve perfection. But in the end, a model is only as good as the value it delivers to the people who use it. Deploying a machine learning model is therefore an essential step: it puts the model into action so you can see its real-world impact on the people it was designed to help.

In this article, we create a very basic API for our PyTorch model. We use Flask to build the REST API and AWS EC2 as our deployment platform of choice. This article breaks down the task into four major steps.

Setup & Installing Dependencies

Creating & Using a Virtual Environment (optional)

We will be using a virtual environment to keep our Python dependencies isolated. Run the following commands to install, create, and activate our virtual environment.

$ python -m pip install --user virtualenv
$ python -m venv env
$ source env/bin/activate

Installing Requirements

Create a requirements.txt file with the following content.

# requirements.txt

torch==1.13.1
torchvision==0.14.1
Flask==2.2.2 
pandas==1.5.2 
tqdm==4.64.1
matplotlib==3.6.3
gunicorn==20.1.0

Then simply run the command pip install -r requirements.txt to install the dependencies listed in the requirements.txt file.
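To confirm the environment is set up correctly, a quick sanity check like the one below should print the pinned versions (the expected values in the comments assume the pins above).

# check_env.py -- quick sanity check that the core packages are importable
import torch
import torchvision
import flask

print(torch.__version__)        # expected: 1.13.1
print(torchvision.__version__)  # expected: 0.14.1
print(flask.__version__)        # expected: 2.2.2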

Writing the code

Project structure

We follow the standard Flask project file structure. We have a static folder that contains our static assets such as CSS and images. We also have a templates directory that contains our HTML templates. Our main directory is app, and it contains two files, __init__.py and main.py. For our purposes, we leave __init__.py empty and write the routes and other relevant code in main.py.

app/
    __init__.py
    main.py
static/
    /images
        /resnet/
templates/
requirements.txt
run.sh

Importing dependencies

We import the essential dependencies.

# main.py
from torchvision.models import resnet50, ResNet50_Weights
from flask import Flask, request, render_template, url_for
from werkzeug.utils import secure_filename
import urllib.request
import io
from PIL import Image 
import os

Writing the create_app function

In main.py, we write a create_app function that returns a Flask object.

# main.py
def create_app():
    app = Flask('image-classifier')

    ...

    return app

Initializing the PyTorch Model

For this tutorial, we use a pre-trained resnet50 model available from torchvision for convenience. Feel free to replace this with your own model. We then call .eval() to put the model in evaluation mode, so that layers like dropout and batch normalization behave correctly at inference time.

# main.py
def create_app():

    app = Flask('image-classifier') 

    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()
    preprocess = weights.transforms()

For this model, we create a preprocess object from the weights' bundled transforms. It resizes, center-crops, and normalizes images the same way the model saw them during training, and we will apply it to every image we feed into the model.
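To get a feel for what preprocess does, here is a small standalone sketch; sample.jpg is just a placeholder for any local RGB image.

# preprocess_demo.py -- a minimal sketch of the weights' bundled transform
from torchvision.models import ResNet50_Weights
from PIL import Image

weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

img = Image.open('sample.jpg')   # placeholder path; any RGB image works
tensor = preprocess(img)         # resize, center-crop, normalize
print(tensor.shape)              # torch.Size([3, 224, 224])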

Defining the routes

In this app, we define three view functions across two routes: / and /classify. The / route presents the graphical user interface for interacting with the API. The /classify route is the classification API itself, which we can call from the command line or a REST inspector application. The classify_gui function contains the code for the GUI's interaction with the classification API; it differs slightly from the classify function because it handles an image upload.

def create_app():

    ...

    @app.route('/classify')
    def classify():
        ...

    @app.route('/')
    def gui():
        ...

    @app.route('/', methods=['POST'])
    def classify_gui():
        ...

    return app

Defining the classification endpoint

Flask allows you to define routes through the @app.route decorator. In the classify function, we implement what happens when a user performs a GET request to the /classify route. We nest this function inside the create_app function for convenience: it gives classify() access to the model, weights, and preprocess objects without passing them in as parameters.

# main.py
def create_app():

    ...

    @app.route('/classify')
    def classify(): 
        url = io.BytesIO(urllib.request.urlopen(request.get_json()['url']).read())
        img = Image.open(url)

        batch = preprocess(img).unsqueeze(0) 
        prediction = model(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = weights.meta["categories"][class_id] 
        return  { 'class': category_name, 'confidence': float("{:.2f}".format(score)) } 


In this function, we accept a JSON body containing a url value. We parse the url from the request body, download the image, and load it as a Pillow image. We then apply model-specific pre- and post-processing. The resnet50 model expects a batch of images, and since we are only passing it a single image, we call unsqueeze to turn it into a batch of one. After passing the batch to the model, we apply the softmax function to obtain class probabilities and take the index of the maximum value as the predicted class.
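To make the tensor shapes concrete, here is roughly what happens at each step; the shapes assume the standard ImageNet-pretrained resnet50 with 1000 classes.

# shape walkthrough (illustrative; assumes img is an already-loaded PIL image)
tensor = preprocess(img)                   # torch.Size([3, 224, 224])
batch = tensor.unsqueeze(0)                # torch.Size([1, 3, 224, 224]) -- a batch of one
logits = model(batch)                      # torch.Size([1, 1000]) -- one raw score per class
prediction = logits.squeeze(0).softmax(0)  # torch.Size([1000]) -- probabilities summing to 1
class_id = prediction.argmax().item()      # index of the most likely class
score = prediction[class_id].item()        # its probability, returned as the confidence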

Defining the GUI endpoint

The / route will contain the graphical interface for users to interact with the model. Here is how it will look.

https://i.imgur.com/55IhIzq.png

We use the render_template function to return HTML.

    @app.route('/')   
    def gui():
        return render_template('gui.html')

We create the base.html file; gui.html will inherit from it.

<!--base.html-->
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">  
    <title>ML API</title>
</head>
<body>
{% block container %}{% endblock %}
</body>
</html>

We create the gui.html file.

<!--gui.html-->
{% extends 'base.html' %}

{% block container %}
<div class="home"> 
<h1>Image Classification</h1>
<h3>with Pytorch ResNet50</h3>  
<h2>CLI</h2>
<table border="1">
    <tr>
        <td><strong>curl</strong></td>
        <td>curl --request GET \
            --url BASE_URL/classify \
            --header 'Content-Type: application/json' \
            --data '{
              "url": IMG_URL
          }'</td>
    </tr>
    <tr>
        <td><strong>wget</strong></td>
        <td>wget --quiet \
        --method GET \
        --header 'Content-Type: application/json' \
        --body-data '{ "url": IMG_URL }' \
        --output-document \
        - BASE_URL/classify</td>
    </tr>
     </table>
</div>

<h2>GUI</h2>
<form method="post" enctype="multipart/form-data" >
    {{img_path}}
    <input type="file" name="file" id="img_upload"><br><br>

    <img id="image" height="400px" src={{img_url}}>

    <br>
    <input type="submit" value="Classify">
</form> 

<br>
<div>{{msg}}</div>
<div>{{result}}</div>
<script>
    document.getElementById('img_upload').onchange = function ( ) { 
   var src = URL.createObjectURL(this.files[0])
  document.getElementById('image').src = src
}
</script>
{% endblock %}

Then we craft the classify_gui function. This function gets called when the user presses the classify button.

    @app.route('/', methods=['POST'])  
    def classify_gui():
        clear_dir('./static/images/') 
        if 'file' not in request.files:
            return render_template('gui.html', msg='No file found')

        file = request.files['file']
        if file.filename == '':
            return render_template('gui.html', msg='No file selected')

        if file and allowed_file(file.filename):
            filename = secure_filename(file.filename)
            path = f'static/images/{filename}'
            file.save(path)

            img = Image.open(path)    
            batch = preprocess(img).unsqueeze(0) 

            prediction = model(batch).squeeze(0).softmax(0)
            class_id = prediction.argmax().item()
            score = prediction[class_id].item()
            category_name = weights.meta["categories"][class_id]  

            return render_template(
                'gui.html', result = f"{category_name}: {100 * score:.1f}%", 
                img_url = f"{url_for('static', filename=f'/images/{filename}')}"
            )

        return render_template('gui.html', msg='File type not allowed')

Code Summary

Here is the whole code for main.py

# main.py
from torchvision.models import resnet50, ResNet50_Weights
from flask import Flask, request, render_template, url_for
from werkzeug.utils import secure_filename
import urllib.request
import io
from PIL import Image 
import os 

def create_app():

    app = Flask('image-classifier') 

    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()
    preprocess = weights.transforms()

    @app.route('/classify')
    def classify(): 
        url = io.BytesIO(urllib.request.urlopen(request.get_json()['url']).read())
        img = Image.open(url)

        batch = preprocess(img).unsqueeze(0) 
        prediction = model(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = weights.meta["categories"][class_id] 
        return  { 'class': category_name, 'confidence': float("{:.2f}".format(score)) } 

    @app.route('/')   
    def gui():
        return render_template('gui.html')

    @app.route('/', methods=['POST'])  
    def classify_gui():
        clear_dir('./static/images/') 
        if 'file' not in request.files:
            return render_template('gui.html', msg='No file found')

        file = request.files['file']
        if file.filename == '':
            return render_template('gui.html', msg='No file selected')

        if file and allowed_file(file.filename):
            filename = secure_filename(file.filename)
            path = f'static/images/{filename}'
            file.save(path)

            img = Image.open(path)    
            batch = preprocess(img).unsqueeze(0) 

            prediction = model(batch).squeeze(0).softmax(0)
            class_id = prediction.argmax().item()
            score = prediction[class_id].item()
            category_name = weights.meta["categories"][class_id]  

            return render_template(
                'gui.html', result = f"{category_name}: {100 * score:.1f}%", 
                img_url = f"{url_for('static', filename=f'/images/{filename}')}"
            )

        return render_template('gui.html', msg='File type not allowed')

    return app

def clear_dir(dir_path):
    # remove previously uploaded images so the static folder does not grow unbounded
    for file in os.listdir(dir_path):
        os.remove(f'{dir_path}/{file}')

def allowed_file(filename):
    # only accept filenames with a png/jpg/jpeg extension (case-insensitive)
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg'}

This code is also available on GitHub.

Localhost testing

At this point, we can already run and use the app on localhost. Just create a run.sh file with the following commands. Setting the host to 0.0.0.0 is necessary because it allows computers other than the host machine to access our Flask app.

# run.sh
source env/bin/activate
export FLASK_APP="app.main:create_app"
export FLASK_RUN_PORT=8080
export FLASK_APP_HOSTNAME="0.0.0.0" 
flask run --host=$FLASK_APP_HOSTNAME

Then run the shell script.

$ sh run.sh

To make a request, simply paste this sample request into another terminal window.

$ curl --request GET \
  --url localhost:8080/classify \
  --header 'Content-Type: application/json' \
  --data '{
    "url": "https://external-content.duckduckgo.com/iu/?u=https%3A%2F%2F2.bp.blogspot.com%2F-HB23AuZMkCc%2FUmFMY7DNM8I%2FAAAAAAAAA8Q%2FFD8D50slVJ4%2Fs1600%2FAlbatross-Bird-Pic.jpg&f=1&nofb=1&ipt=0ac75b04d5d023fa362e560745360a7aad5e15c90e2673d887ff40851e71d9f3&ipo=images"
}'

Or launch your favorite API inspector, set the Content-Type header to application/json, and supply a JSON object with the url of your chosen image. Voila! You have now created a REST API for your machine learning model!
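If you prefer Python over curl, a minimal client sketch using the requests library (not listed in requirements.txt, so install it separately) might look like this; the image URL below is only a placeholder.

# client.py -- a minimal Python client for the /classify endpoint
import requests

resp = requests.get(
    'http://localhost:8080/classify',
    json={'url': 'https://example.com/albatross.jpg'},  # replace with any public image URL
)
print(resp.json())  # e.g. {'class': 'albatross', 'confidence': 0.92}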

https://i.imgur.com/43vqwbm.png

In this image, we pass the URL of an image and receive a prediction. The app being used here is Insomnia.

Deploying to the Cloud

Deploying your API to your localhost is one thing. Deploying your API to the totality of the interwebs is a whole different matter! Suddenly you have to worry about a whole slew of complexities such as DNS, inbound and outbound rules, virtual machines, and so much more! Good thing there are a bajillion articles and tutorials on the internet about just that!

Creating EC2 Instance

Since there are a lot of EC2 tutorials, we can simply follow this one on YouTube: https://www.youtube.com/watch?v=oqHfiRzxunY. Make sure to select the Ubuntu Server AMI and the t2.micro instance type, as both are available in the free tier.

Modifying the security group and Inbound rules

We add an inbound rule to allow other computers to communicate with our instance. We select Custom TCP as the type and 8080 as the port range. We set the source to 0.0.0.0/0 to allow connections from any IP address.

https://i.imgur.com/DqhiJHG.png

Opening the port in the firewall

Run the following commands in the EC2 terminal to open port 8080. We also allow SSH first so that enabling the firewall does not lock us out of the instance.

$ sudo ufw allow ssh
$ sudo ufw allow 8080
$ sudo ufw enable

Getting your code into the EC2 instance

There are two ways to get your code into the EC2 instance. One is through scp, a Linux command that allows file transfer between authenticated machines. Another quick-but-dirty way is to create a public git repository; this way you don't have to deal with Linux complexities. In this article, we'll simply use the repository route: we create a git repository containing the code, publish it, and clone the public repository on the EC2 instance.

Deploying

Run the following commands to install pip and the necessary libraries.

$ sudo apt update
$ sudo apt install python3-pip
$ pip install -r requirements.txt

Then simply run run.sh again and type http://PUBLIC_IP_ADDRESS:8080 into any web browser, using your EC2 instance's public IP address. Remember to use http instead of https. Voila! You can now access your machine learning API from any part of the world.

Ensuring scalability and security (Part 2 Coming Soon)

We have deployed our Flask app, but the Flask team themselves do not recommend deploying bare Flask apps to production. This is because Flask does not come with built-in support for load balancing, security, database management, and so on. That will be discussed in another article, though, as this one has already gone on long enough.
