So you have finally trained and built your machine learning model. It took you days of research, training, and fine-tuning to achieve perfection. But in the end, a model is only as good as the good it does for the people who use it. Deploying a machine learning model is therefore an essential step: it lets you put the model into action and see its real-world impact on the people it was designed to help.
In this article, we create a very basic API for our PyTorch model. We use Flask to build the REST API and AWS EC2 as our deployment platform of choice. The article breaks the task down into four major steps.
Setup & Installing Dependencies
Creating & Using a Virtual Environment (optional)
We will use Python's built-in venv module to create an isolated Python environment. Run the following commands to create and activate it.
$ python -m venv env
$ source env/bin/activate
Installing Requirements
Create a requirements.txt file with the following content.
# requirements.txt
torch==1.13.1
torchvision==0.14.1
Flask==2.2.2
Werkzeug==2.2.2
pandas==1.5.2
tqdm==4.64.1
matplotlib==3.6.3
gunicorn==20.1.0
Then simply run the following command to install the dependencies listed in requirements.txt.
$ pip install -r requirements.txt
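To double-check that the environment really matches the pins, a small standard-library sketch can compare requirements.txt against what is installed. It assumes every line uses the simple name==version form shown above; parse_pins and find_mismatches are hypothetical helper names, not part of pip.

```python
from importlib import metadata


def parse_pins(text):
    """Parse `name==version` lines, ignoring blank lines and # comments."""
    pins = {}
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()
        if not line:
            continue
        name, sep, version = line.partition("==")
        if sep:  # skip anything that is not an exact pin
            pins[name.strip()] = version.strip()
    return pins


def find_mismatches(pins, get_version=metadata.version):
    """Return {name: (wanted, installed or None)} for every pin that is not satisfied."""
    problems = {}
    for name, wanted in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:
            problems[name] = (wanted, installed)
    return problems
```

From the project folder, `print(find_mismatches(parse_pins(open("requirements.txt").read())))` prints an empty dict when everything matches. The comparison is an exact string match, so a local build label such as 1.13.1+cpu shows up as a mismatch even though it is the same release.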
Writing the code
Project structure
We follow the standard Flask project file structure. The static folder contains our static assets such as CSS and images, and the templates directory contains our HTML templates. Our main directory is app, which contains two files, __init__.py and main.py. For our purposes, we leave __init__.py empty and write the routes and other relevant code in main.py.
app/
    __init__.py
    main.py
static/
    images/
    resnet/
templates/
requirements.txt
run.sh
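If you prefer to create this layout from a script, here is a minimal pathlib sketch. The scaffold helper is hypothetical; it only creates empty folders and files (including base.html and gui.html, the two templates written later) and never overwrites existing content.

```python
from pathlib import Path

# Folders and files from the project tree above
DIRS = ["app", "static/images", "static/resnet", "templates"]
FILES = [
    "app/__init__.py",
    "app/main.py",
    "templates/base.html",
    "templates/gui.html",
    "requirements.txt",
    "run.sh",
]


def scaffold(root="."):
    """Create the project skeleton under root; existing files are left untouched."""
    root = Path(root)
    for d in DIRS:
        (root / d).mkdir(parents=True, exist_ok=True)
    for f in FILES:
        (root / f).touch(exist_ok=True)  # touch never truncates an existing file
    return root
```

Calling `scaffold()` from an empty folder produces the tree above, ready to be filled in.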
Importing dependencies
We import the essential dependencies.
# main.py
import io
import os
import urllib.request
from flask import Flask, render_template, request, url_for
from PIL import Image
from torchvision.models import ResNet50_Weights, resnet50
from werkzeug.utils import secure_filename
Writing the create_app function
In main.py, we write a create_app function that returns a Flask object.
# main.py
def create_app():
    app = Flask('image-classifier')
    ...
    ...
    return app
Initializing the Pytorch Model
For this tutorial, we use a pre-trained resnet50 model available from torchvision for convenience. Feel free to replace it with your own model. We then call .eval() to put the model in evaluation mode, which switches layers such as dropout and batch normalization to their inference behavior.
# main.py
def create_app():
    app = Flask('image-classifier')
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()
    preprocess = weights.transforms()
We also create a preprocess object, which is the pre-processing transform (resize, center crop, and normalization) that ships with the resnet50 weights. We use this preprocess object to transform the images we feed into the model.
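The transform bundled with ResNet50_Weights.DEFAULT (the IMAGENET1K_V2 weights) resizes the shorter image side to 232 pixels, center-crops a 224 x 224 region, rescales pixel values to [0, 1], and normalizes each channel with the ImageNet mean and standard deviation. The pure-Python sketch below reproduces that arithmetic for one image size and one pixel. It is only an illustration of what preprocess does, not a replacement for weights.transforms().

```python
# Constants used by the transform of ResNet50_Weights.IMAGENET1K_V2 (the DEFAULT weights)
RESIZE = 232
CROP = 224
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def resized_size(width, height, size=RESIZE):
    """Scale so the shorter side equals `size`, keeping the aspect ratio (truncating)."""
    short, long = (width, height) if width <= height else (height, width)
    new_long = int(size * long / short)
    return (size, new_long) if width <= height else (new_long, size)


def normalize_pixel(rgb):
    """Rescale an 8-bit RGB pixel to [0, 1], then standardize each channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))
```

For a 640 x 480 photo, `resized_size(640, 480)` gives `(309, 232)`, and the central CROP x CROP region of that is what reaches the model. A pixel equal to the mean colour normalizes to roughly (0, 0, 0).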
Defining the routes
In this app, we have two routes, / and /classify, handled by three view functions. The / route presents the graphical user interface for interacting with the API: the gui function serves the page on a GET request, and the classify_gui function handles the POST request sent when the user uploads an image through that page. The /classify route is the classification API itself, which we can call from a command line interface or a REST inspector application. Unlike classify, classify_gui receives an uploaded image file instead of an image URL.
def create_app():
    ...
    ...
    @app.route('/classify')
    def classify():
        ...
        ...
    @app.route('/')
    def gui():
        ...
        ...
    @app.route('/', methods=['POST'])
    def classify_gui():
        ...
        ...
    return app
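One way to picture how Flask picks a view function is a lookup keyed by HTTP method and path. The sketch below is a hypothetical, much simplified stand-in for Flask's router (real Flask also answers HEAD and OPTIONS automatically). It shows that the same / path maps to two different functions, and that a method a route does not declare gets 405 Method Not Allowed rather than 404.

```python
# Simplified routing table mirroring main.py: (HTTP method, path) -> view function name
ROUTES = {
    ("GET", "/classify"): "classify",
    ("GET", "/"): "gui",
    ("POST", "/"): "classify_gui",
}


def dispatch(method, path):
    """Return (status code, view name or error text) for a request."""
    view = ROUTES.get((method.upper(), path))
    if view is not None:
        return 200, view
    if any(route_path == path for _, route_path in ROUTES):
        return 405, "Method Not Allowed"  # the path exists, the method does not
    return 404, "Not Found"
```

For example, `dispatch("POST", "/")` returns `(200, "classify_gui")`, while `dispatch("POST", "/classify")` returns `(405, "Method Not Allowed")` because /classify only declares GET.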
Defining the classification endpoint
Flask lets you define routes with the @app.route decorator. In the classify function, we implement what happens when a user sends a GET request to the /classify route. We nest this function inside create_app for convenience: it can then use model (along with preprocess and weights) directly, without passing them as parameters to classify().
# main.py
def create_app():
    ...
    ...
    @app.route('/classify')
    def classify():
        # Download the image from the URL given in the JSON request body
        image_bytes = io.BytesIO(urllib.request.urlopen(request.get_json()['url']).read())
        # Convert to 3-channel RGB so PNGs with transparency and grayscale images also work
        img = Image.open(image_bytes).convert('RGB')
        batch = preprocess(img).unsqueeze(0)  # add a batch dimension: (1, 3, H, W)
        prediction = model(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = weights.meta["categories"][class_id]
        return {'class': category_name, 'confidence': round(score, 2)}
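The nesting works because Python closures capture variables from the enclosing function. In the standard-library sketch below, load_model is a hypothetical stand-in for resnet50(weights=weights) that counts how often it runs. Both inner functions use the same model object, and it is loaded once per create_app call rather than on every request.

```python
def load_model():
    """Hypothetical stand-in for resnet50(weights=...); records each load."""
    load_model.calls += 1
    return lambda x: x * 2  # pretend model


load_model.calls = 0


def create_app():
    model = load_model()  # runs once, when the app is created

    def classify(x):
        return model(x)  # `model` comes from the enclosing scope

    def classify_gui(x):
        return model(x) + 1  # shares the very same object

    return {"classify": classify, "classify_gui": classify_gui}
```

Calling the returned functions any number of times leaves `load_model.calls` at 1, which is the same reason main.py pays the cost of building ResNet50 only at startup.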
In this function, we accept a JSON body containing a url field. We read the url from the request body, download the image it points to, and load it as a Pillow image. We then apply some model-specific pre- and post-processing steps. The resnet50 model expects a batch of images, that is, a tensor of shape (N, 3, H, W). Since we pass it only one image, we call unsqueeze(0) to turn the (3, H, W) image tensor into a batch of one. After passing the batch to model, we call squeeze(0) to drop the batch dimension again, apply the softmax function to turn the raw scores into probabilities, and take the index of the highest probability with argmax. Finally, we look up that index in weights.meta["categories"] to get a human-readable class name.
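The post-processing can be mirrored in plain Python to see what softmax(0), argmax().item() and the rounding do to a vector of raw scores. This is a sketch for illustration: the real model outputs 1000 ImageNet scores, and the three scores and names in the usage note are made up.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities; subtracting the max avoids overflow."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top1(logits, categories):
    """Mimic classify(): pick the most likely class and round its probability."""
    probs = softmax(logits)
    class_id = max(range(len(probs)), key=probs.__getitem__)  # like .argmax().item()
    return {"class": categories[class_id], "confidence": round(probs[class_id], 2)}
```

With made-up scores, `top1([1.0, 3.0, 0.5], ["goose", "albatross", "gull"])` returns `{"class": "albatross", "confidence": 0.82}`.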
Defining the GUI endpoint
The / route serves the graphical interface for users to interact with the model. Here is what it will look like. We use Flask's render_template function to return the HTML page.
@app.route('/')
def gui():
    return render_template('gui.html')
We first create the base.html file in the templates directory. The gui.html template will inherit from it.
<!--base.html-->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>ML API</title>
</head>
<body>
{% block container %}{% endblock %}
</body>
</html>
Next, we create the gui.html file, which extends base.html.
<!--gui.html-->
{% extends 'base.html' %}
{% block container %}
<div class="home">
<h1>Image Classification</h1>
<h3>with PyTorch ResNet50</h3>
<h2>CLI</h2>
<table border="1">
<tr>
<td><strong>curl</strong></td>
<td>curl --request GET \
--url BASE_URL/classify \
--header 'Content-Type: application/json' \
--data '{
"url": "IMG_URL"
}'</td>
</tr>
<tr>
<td><strong>wget</strong></td>
<td>wget --quiet \
--method GET \
--header 'Content-Type: application/json' \
--body-data '{"url": "IMG_URL"}' \
--output-document \
- BASE_URL/classify</td>
</tr>
</table>
</div>
<h2>GUI</h2>
<form method="post" enctype="multipart/form-data">
{{ msg }}
<input type="file" name="file" id="img_upload"><br><br>
<img id="image" height="400" src="{{ img_url }}">
<br>
<input type="submit" value="Classify">
</form>
<br>
<div>{{result}}</div>
<script>
document.getElementById('img_upload').onchange = function ( ) {
var src = URL.createObjectURL(this.files[0])
document.getElementById('image').src = src
}
</script>
{% endblock %}
Then we write the classify_gui function. Flask calls it when the user presses the Classify button, which submits the upload form as a POST request to /.
@app.route('/', methods=['POST'])
def classify_gui():
    # Remove images left over from previous uploads
    clear_dir('./static/images/')
    if 'file' not in request.files:
        return render_template('gui.html', msg='No file found')
    file = request.files['file']
    if file.filename == '':
        return render_template('gui.html', msg='No file selected')
    if not allowed_file(file.filename):
        # Without this branch the view would return None and Flask would raise an error
        return render_template('gui.html', msg='Only .png, .jpg and .jpeg files are allowed')
    filename = secure_filename(file.filename)
    path = f'static/images/{filename}'
    file.save(path)
    img = Image.open(path).convert('RGB')
    batch = preprocess(img).unsqueeze(0)
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    return render_template(
        'gui.html',
        result=f"{category_name}: {100 * score:.1f}%",
        img_url=url_for('static', filename=f'images/{filename}'),
    )
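allowed_file only looks at the file name, so a renamed non-image still passes and Pillow then fails on it. As an optional extra check, a hypothetical looks_like_image helper could peek at the first bytes of the upload (for example file.stream) for the PNG or JPEG signature before saving. This is a sketch, not part of the original app.

```python
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"  # first 8 bytes of every PNG file
JPEG_SIGNATURE = b"\xff\xd8\xff"  # start-of-image marker FF D8, then the FF of the next marker


def looks_like_image(stream):
    """Return True if a seekable binary stream starts like a PNG or JPEG file."""
    head = stream.read(8)
    stream.seek(0)  # rewind so the whole upload can still be saved afterwards
    return head.startswith(PNG_SIGNATURE) or head.startswith(JPEG_SIGNATURE)
```

Inside classify_gui it would sit next to the extension check, for instance `if not allowed_file(file.filename) or not looks_like_image(file.stream): ...`.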
Code Summary
Here is the complete code for main.py.
# main.py
import io
import os
import urllib.request
from flask import Flask, render_template, request, url_for
from PIL import Image
from torchvision.models import ResNet50_Weights, resnet50
from werkzeug.utils import secure_filename
def create_app():
    app = Flask('image-classifier')
    # Load the pre-trained model and its matching preprocessing transform once
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights)
    model.eval()
    preprocess = weights.transforms()
    @app.route('/classify')
    def classify():
        # Download the image from the URL given in the JSON request body
        image_bytes = io.BytesIO(urllib.request.urlopen(request.get_json()['url']).read())
        img = Image.open(image_bytes).convert('RGB')
        batch = preprocess(img).unsqueeze(0)
        prediction = model(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = weights.meta["categories"][class_id]
        return {'class': category_name, 'confidence': round(score, 2)}
    @app.route('/')
    def gui():
        return render_template('gui.html')
    @app.route('/', methods=['POST'])
    def classify_gui():
        clear_dir('./static/images/')
        if 'file' not in request.files:
            return render_template('gui.html', msg='No file found')
        file = request.files['file']
        if file.filename == '':
            return render_template('gui.html', msg='No file selected')
        if not allowed_file(file.filename):
            return render_template('gui.html', msg='Only .png, .jpg and .jpeg files are allowed')
        filename = secure_filename(file.filename)
        path = f'static/images/{filename}'
        file.save(path)
        img = Image.open(path).convert('RGB')
        batch = preprocess(img).unsqueeze(0)
        prediction = model(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = weights.meta["categories"][class_id]
        return render_template(
            'gui.html',
            result=f"{category_name}: {100 * score:.1f}%",
            img_url=url_for('static', filename=f'images/{filename}'),
        )
    return app
def clear_dir(dir_path):
    # Create the folder if it is missing (git does not track empty folders),
    # then delete images left over from previous uploads
    os.makedirs(dir_path, exist_ok=True)
    for name in os.listdir(dir_path):
        os.remove(os.path.join(dir_path, name))
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg'}
This code is also available on GitHub.
Localhost testing
At this point, we can already run and use the app on localhost. Just create a run.sh file with the following commands. Setting the host to 0.0.0.0 is necessary because it makes the Flask server listen on all network interfaces, so computers other than the host can reach the app.
# run.sh
source env/bin/activate
export FLASK_APP="app.main:create_app"
export FLASK_RUN_PORT=8080
export FLASK_APP_HOSTNAME="0.0.0.0"
flask run --host=$FLASK_APP_HOSTNAME
Then run the script with bash, since it uses source, which the plain sh shell on Ubuntu does not support.
$ bash run.sh
To send a request, paste this sample request into another terminal window.
$ curl --request GET \
--url localhost:8080/classify \
--header 'Content-Type: application/json' \
--data '{
"url": "https://external-content.duckduckgo.com/iu/?u=https%3A%2F%2F2.bp.blogspot.com%2F-HB23AuZMkCc%2FUmFMY7DNM8I%2FAAAAAAAAA8Q%2FFD8D50slVJ4%2Fs1600%2FAlbatross-Bird-Pic.jpg&f=1&nofb=1&ipt=0ac75b04d5d023fa362e560745360a7aad5e15c90e2673d887ff40851e71d9f3&ipo=images"
}'
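The same request can also be sent from Python using only the standard library. classify_remote is a hypothetical helper name. Like the curl command, it sends the JSON body with a GET request, since that is the only method the /classify route accepts.

```python
import json
import urllib.request


def classify_remote(base_url, image_url, timeout=30):
    """Send {"url": image_url} to BASE_URL/classify and return the parsed JSON reply."""
    payload = json.dumps({"url": image_url}).encode("utf-8")
    req = urllib.request.Request(
        f"{base_url}/classify",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="GET",  # urllib would default to POST when data is given
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

With the app running, `classify_remote("http://localhost:8080", IMAGE_URL)` returns a dict with class and confidence keys, where IMAGE_URL is any image link, such as the one in the curl example.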
Alternatively, launch your favorite API inspector, send a GET request to localhost:8080/classify with the Content-Type header set to application/json, and supply a JSON object with the url of your chosen image. Voilà! You have now created a REST API for your machine learning model!
In this screenshot, we pass the URL of an image and receive a prediction. The app used here is Insomnia.
Deploying to the Cloud
Deploying your API on localhost is one thing. Deploying it to the totality of the interwebs is a whole different matter! Suddenly you have to worry about a whole slew of complexities such as DNS, inbound and outbound rules, virtual machines, and much more. Fortunately, there are a bajillion articles and tutorials on the internet about exactly that!
Creating EC2 Instance
Since there are plenty of EC2 tutorials, we can simply follow this one on YouTube: https://www.youtube.com/watch?v=oqHfiRzxunY. Make sure to select the Ubuntu Server AMI and the t2.micro instance type, as both are available in the free tier.
Modifying the security group and Inbound rules
We add an inbound rule to the instance's security group so that other computers can communicate with our instance. We select Custom TCP as the type and 8080 as the port range, and set the source to 0.0.0.0/0 to allow traffic from any IP address.
Opening the port in the firewall
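Run the following commands in the EC2 terminal to open port 8080 in Ubuntu's ufw firewall. Allow SSH (port 22) before enabling ufw; otherwise the firewall will block your own SSH connections to the instance.
$ sudo ufw allow 22/tcp
$ sudo ufw allow 8080
$ sudo ufw enable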
Getting your code into the EC2 instance
There are two ways to get your code into the EC2 instance. One is scp, a Linux command that copies files over SSH to a machine you are authenticated on. Another quick but dirty way is to create a public Git repository, so you don't have to deal with those Linux complexities. In this article, we simply take the repository route: we create a Git repository containing the code, publish it, and clone it on the EC2 instance.
Deploying
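Run the following commands inside the cloned project folder to install pip and venv, create the virtual environment that run.sh activates, and install the necessary libraries. The --no-cache-dir flag stops pip from caching the large PyTorch download, which can help on a small instance such as the 1 GiB t2.micro.
$ sudo apt update
$ sudo apt install python3-pip python3-venv
$ python3 -m venv env
$ source env/bin/activate
$ pip install --no-cache-dir -r requirements.txt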
Then run run.sh again (bash run.sh) and open http://PUBLIC_IP_ADDRESS:8080 in any web browser, replacing PUBLIC_IP_ADDRESS with your EC2 instance's public IP address. Remember to use http rather than https, since the app does not serve TLS. Voilà! You can now access your machine learning API from anywhere in the world.
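If the page does not load, a quick way to narrow things down is to test from your own machine whether port 8080 accepts TCP connections at all. The port_is_open helper below is a hypothetical standard-library sketch. False points to the security group, the ufw rules, or run.sh not running; True combined with a failing page points to the app itself.

```python
import socket


def port_is_open(host, port, timeout=3.0):
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # refused, unreachable, or timed out (socket.timeout is an OSError)
        return False
```

For example, run `port_is_open("PUBLIC_IP_ADDRESS", 8080)` with your instance's address filled in.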
Ensuring scalability and security (Part 2 Coming Soon)
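We have deployed our Flask app, but the Flask team does not recommend using the built-in development server (the one flask run starts) in production, as it is not designed to be efficient, stable, or secure. Production deployments typically run the app with a WSGI server such as Gunicorn (already listed in requirements.txt), usually behind a reverse proxy, with concerns such as load balancing handled on top. That will be discussed in another article, though, as this one has gone on long enough.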
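As a starting point, Gunicorn reads its settings from a plain Python file. Below is a minimal gunicorn.conf.py sketch, assuming the project layout above and the Gunicorn 20.1.0 pinned in requirements.txt. The worker count and timeout are assumptions sized for a 1 GiB t2.micro, since every worker loads its own copy of ResNet50.

```python
# gunicorn.conf.py: Gunicorn settings are plain module-level variables
import multiprocessing

# Listen on all interfaces, on the port opened in the security group and ufw
bind = "0.0.0.0:8080"

# Gunicorn's docs suggest (2 x CPUs) + 1 workers, but each worker holds a full
# ResNet50 in memory, so this sketch caps the count at 2 for a small instance
workers = min(2, multiprocessing.cpu_count() * 2 + 1)

# Give workers extra time, e.g. for downloading the pretrained weights on first start
timeout = 120
```

From the project folder, something like `gunicorn -c gunicorn.conf.py "app.main:create_app()"` would then replace flask run; the quoted create_app() form asks Gunicorn to call the factory to obtain the app.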