DEV Community

Jaydeep Dey


Integrating a Machine Learning Model with a REST API in Flask

In this blog post, I will show how to use Flask to serve a machine learning model through a REST API. Flask is a lightweight micro web framework for building back ends in Python.

Wrapping the model in a REST API lets you deploy it in a production environment, where any external client that can send an HTTP request (a web front end, a mobile app, another service) can get predictions through a simple interface. The model is loaded once when the server starts and is then reused for every request, which makes this a practical way to put a machine learning model in front of real users.

Before you start, make sure you create a virtual environment (virtualenv) in the project folder.

Here are the steps:

pip install virtualenv
virtualenv env

Then activate the virtual environment. On Windows (PowerShell), run the following command; on macOS or Linux, run source env/bin/activate instead:

.\env\Scripts\activate.ps1

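Install Flask and Flask-CORS with pip. Unpickling the models later also requires the library they were trained with (for example, scikit-learn), so install that in the environment as well: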

pip install flask flask_cors

Make sure your trained models are saved as .pkl files. For instance, I dumped my two models with the pickle module as follows:

import pickle

# tfidf is the fitted vectorizer and mnb the trained classifier
with open('vectorizer.pkl', 'wb') as f:
    pickle.dump(tfidf, f)
with open('model.pkl', 'wb') as f:
    pickle.dump(mnb, f)

Copy the .pkl files into the project directory and put them in a separate folder named model (the Flask code loads them from ./model/).
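If you save the models from a training script, a small helper can write them straight into that folder instead of copying them by hand. This is a minimal sketch: save_models is a hypothetical name, and the default folder model and the file stems model and vectorizer are chosen only to match the paths the Flask app in this post loads.

```python
import pickle
from pathlib import Path


def save_models(models, out_dir="model"):
    """Pickle each object in `models` to <out_dir>/<name>.pkl and return the paths.

    `models` maps a file stem to a fitted object, e.g.
    {"model": mnb, "vectorizer": tfidf}.
    """
    target = Path(out_dir)
    target.mkdir(parents=True, exist_ok=True)  # create the folder if it is missing
    paths = []
    for name, obj in models.items():
        path = target / f"{name}.pkl"
        with path.open("wb") as f:  # the context manager flushes and closes the file
            pickle.dump(obj, f)
        paths.append(path)
    return paths
```

In the training script you would call it as save_models({"model": mnb, "vectorizer": tfidf}).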

Now, in the root directory, create a main.py file and paste this template:

from flask import Flask
app = Flask(__name__)

@app.route('/')
def hello_world():
    return 'Hello, World!'

if __name__ == "__main__":
    app.run(debug=True)

Next, load the two .pkl files from the model directory:

from flask import Flask
import pickle
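# from waitress import serve  # optional: waitress is a production WSGI server
app = Flask(__name__)

# loading the models (only unpickle files you trust)
with open('./model/model.pkl', 'rb') as f:
    model = pickle.load(f)
with open('./model/vectorizer.pkl', 'rb') as f:
    tfidf = pickle.load(f)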

@app.route('/')
def hello_world():
    return 'Hello, World!'

if __name__ == "__main__":
    app.run(debug=True)
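Relative paths such as './model/model.pkl' are resolved against the current working directory, so the app only finds its models when it is started from the project root. A hedged alternative is to resolve the folder relative to main.py itself. The helper name load_pickle and the constants BASE_DIR and MODEL_DIR are assumptions for illustration, not part of the original code.

```python
import pickle
from pathlib import Path

# Resolve the folder relative to this file rather than the current working
# directory, so the models are found no matter where the app is launched from.
# __file__ is defined when running a script; fall back to the CWD otherwise (e.g. a REPL).
BASE_DIR = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd()
MODEL_DIR = BASE_DIR / "model"


def load_pickle(name, model_dir=MODEL_DIR):
    """Load <model_dir>/<name>.pkl. Only unpickle files you trust."""
    with (Path(model_dir) / f"{name}.pkl").open("rb") as f:
        return pickle.load(f)
```

In main.py this would replace the two open calls with model = load_pickle("model") and tfidf = load_pickle("vectorizer"). Keep in mind that pickle can execute arbitrary code while loading, so never load .pkl files from an untrusted source.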

Now define a POST route for your REST API endpoint that returns the model's prediction as a JSON response:

from flask import Flask, request, jsonify
from Controllers import transformText
import pickle
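# from waitress import serve  # optional: waitress is a production WSGI server
app = Flask(__name__)

# loading the models (only unpickle files you trust)
with open('./model/model.pkl', 'rb') as f:
    model = pickle.load(f)
with open('./model/vectorizer.pkl', 'rb') as f:
    tfidf = pickle.load(f)

@app.route('/')
def hello_world():
    return 'Hello, World!'

@app.route('/predict', methods=['POST'])
def predict():
    # parse the JSON request body, e.g. {"sentence": "some text"};
    # silent=True returns None instead of raising if the body isn't valid JSON
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'sentence' not in data:
        return jsonify({'error': "request body must be JSON with a 'sentence' field"}), 400
    # clean the text the same way it was cleaned before training
    transformed_sms = transformText.transform_text(data['sentence'])
    # turn the cleaned text into a TF-IDF feature vector
    vector_input = tfidf.transform([transformed_sms])
    # predict() returns one label per input row; take the first
    result = model.predict(vector_input)[0]
    # cast to a plain int so jsonify can serialize it
    return jsonify({'result': int(result)})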

if __name__ == "__main__":
    app.run(debug=True)

jsonify is a Flask function that serializes Python data (here, a dictionary) to JSON and wraps it in a Response object with the application/json mimetype, ready to be sent to the client.
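One design option (not what the post's code does) is to pull the body of the predict route into a plain function, so the preprocess, vectorize, predict sequence can be unit-tested without starting a server. This is a sketch; predict_label is a hypothetical name, and it only assumes the vectorizer has a transform method and the model has a predict method, as the objects used above do.

```python
def predict_label(sentence, transform, vectorizer, model):
    """Run one sentence through preprocess -> vectorize -> predict.

    transform:  a text-cleaning function (e.g. transform_text)
    vectorizer: any object with .transform(list_of_str), e.g. the TF-IDF vectorizer
    model:      any object with .predict(features), e.g. the trained classifier
    """
    cleaned = transform(sentence)
    features = vectorizer.transform([cleaned])  # the vectorizer expects a list of documents
    label = model.predict(features)[0]          # one prediction per document; take the first
    return int(label)                           # plain int, so jsonify can serialize it
```

Inside the route, the last four lines would then shrink to return jsonify({'result': predict_label(data['sentence'], transformText.transform_text, tfidf, model)}), and tests can pass in lightweight stand-in objects instead of the real models.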

NOTE: Because my model requires text preprocessing, I defined a controller function, transform_text, in a module named transformText inside a Controllers directory in the project root.
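The author's transform_text is not shown in the post, so the following is only an assumed, simplified stand-in for Controllers/transformText.py. It lowercases the text, keeps alphanumeric tokens, and drops a tiny illustrative stop-word list. Whatever your real version does, it must match the preprocessing the vectorizer was fitted on, or the predictions will be unreliable.

```python
import re

# A deliberately small stop-word list for illustration only; a real project
# would more likely use a complete list (for example, one from NLTK).
STOP_WORDS = {"a", "an", "the", "is", "are", "to", "of", "and", "in", "on", "for", "you", "your"}


def transform_text(text):
    """Lowercase, tokenize, drop stop words, and rejoin into one string.

    Hypothetical stand-in for the post's controller function.
    """
    tokens = re.findall(r"[a-z0-9]+", text.lower())  # keep runs of ASCII letters/digits only
    kept = [t for t in tokens if t not in STOP_WORDS]
    return " ".join(kept)
```

For example, transform_text("Congrats! You have WON a free prize.") returns "congrats have won free prize".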

Finally, generate a requirements.txt for your project with the following command:

pip freeze > requirements.txt

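The project directory should look something like this:

project/
├── Controllers/
│   └── transformText.py
├── model/
│   ├── model.pkl
│   └── vectorizer.pkl
├── env/
├── main.py
└── requirements.txt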

Congrats, your REST API is ready 🚀🥳
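To try the endpoint, any HTTP client will do. Here is a minimal standard-library client sketch; the helper names build_predict_request and predict are hypothetical, and the URL assumes Flask's development server on its default address, http://127.0.0.1:5000.

```python
import json
import urllib.request

# Flask's development server listens on http://127.0.0.1:5000 by default.
API_URL = "http://127.0.0.1:5000/predict"


def build_predict_request(sentence, url=API_URL):
    """Build a POST request whose JSON body matches what /predict reads."""
    body = json.dumps({"sentence": sentence}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},  # lets Flask parse the body as JSON
        method="POST",
    )


def predict(sentence, url=API_URL, timeout=10):
    """Send the request and return the value of "result" from the JSON response."""
    with urllib.request.urlopen(build_predict_request(sentence, url), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))["result"]
```

With python main.py running in another terminal, predict("Congratulations, you have won a free prize!") returns the integer label your model predicts.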

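But wait: if you call the API from a browser app served on a different origin (for example, a front-end dev server running on another port), the browser may block the request with a CORS error.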

To resolve it, add the following lines to main.py, right after app = Flask(__name__):

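from flask_cors import CORS  # can also go with the other imports at the top

# Handling CORS: set the config before CORS(app) so the extension reads it when it initializes
app.config['CORS_HEADERS'] = 'Content-Type'
cors = CORS(app)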
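flask_cors takes care of the headers for you, and by default CORS(app) allows every origin (it also accepts an origins argument if you want to restrict that). To show the idea behind the fix, here is a simplified, hypothetical sketch of the core header decision. It is not the extension's actual code and ignores preflight requests, credentials, and allowed headers.

```python
def cors_headers(request_origin, allowed_origins):
    """Return the CORS response headers a server would add for one request.

    The browser sends an Origin header; it only lets the page read the
    response if the server answers with a matching Access-Control-Allow-Origin.
    """
    if "*" in allowed_origins:
        return {"Access-Control-Allow-Origin": "*"}
    if request_origin in allowed_origins:
        # Echo the specific origin and tell caches that the answer varies by Origin.
        return {"Access-Control-Allow-Origin": request_origin, "Vary": "Origin"}
    return {}  # no header: the browser blocks the page from reading the response
```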

And you are done 🤩
