In this post, I will show how to use Flask to serve a machine learning model behind a REST API. Flask is a micro web framework for building back ends in Python. By wrapping the model in a REST API, we can deploy it so that other applications can use it.
Exposing the model through a REST API makes it accessible to external clients over plain HTTP: any client that can send a JSON request can get a prediction back, whatever language it is written in. That makes this a practical option for most projects that need to put a machine learning model in front of users.
Before you start, make sure you create a virtualenv in the project folder. Here are the steps:
pip install virtualenv
virtualenv env
Then activate the virtual environment with the following command (this is the Windows PowerShell form):
.\env\Scripts\activate.ps1
Install the following packages using pip:
pip install flask flask_cors
Make sure you have your models ready as .pkl files. For instance, I have dumped two models using the pickle module as follows:
import pickle

# tfidf (the vectorizer) and mnb (the model) are two fitted machine learning objects
with open('vectorizer.pkl', 'wb') as f:
    pickle.dump(tfidf, f)
with open('model.pkl', 'wb') as f:
    pickle.dump(mnb, f)
Copy the .pkl files into the project directory and put them in a separate folder named model, which is the folder the loading code reads from.
Now, in the root directory, create a main.py file and paste this template:
from flask import Flask

app = Flask(__name__)

@app.route('/')
def hello_world():
    return 'Hello, World!'

if __name__ == "__main__":
    app.run(debug=True)
Next, load the two .pkl files from the model directory:
from flask import Flask
import pickle
# from waitress import serve

app = Flask(__name__)

# importing models
model = pickle.load(open('./model/model.pkl', 'rb'))
tfidf = pickle.load(open('./model/vectorizer.pkl', 'rb'))

@app.route('/')
def hello_world():
    return 'Hello, World!'

if __name__ == "__main__":
    app.run(debug=True)
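The relative paths above are resolved against the directory you launch Python from, so starting the app from another folder can fail with a confusing error. A minimal sketch of a loader that reports a missing file clearly; `load_artifact` is a hypothetical helper, not part of the original project, and a temporary directory with a plain dictionary stands in for the real model files:

```python
import pickle
import tempfile
from pathlib import Path


def load_artifact(model_dir, file_name):
    """Load one pickled object from model_dir, failing with a clear message."""
    path = Path(model_dir) / file_name
    if not path.is_file():
        raise FileNotFoundError(f"Expected a pickled artifact at {path}")
    with path.open("rb") as f:  # the with-block closes the file handle
        return pickle.load(f)


# In main.py the directory could be anchored to the script itself, e.g.
#   MODEL_DIR = Path(__file__).resolve().parent / "model"
# Here a temporary directory and a plain dict stand in for the real files.
with tempfile.TemporaryDirectory() as tmp_dir:
    with open(Path(tmp_dir) / "model.pkl", "wb") as f:
        pickle.dump({"stand_in": True}, f)
    loaded = load_artifact(tmp_dir, "model.pkl")

print(loaded)
```

Keep in mind that unpickling can execute arbitrary code, so only load .pkl files that you created yourself or otherwise trust.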
Now define a POST route for your REST API endpoint, which sends the prediction back as the response:
from flask import Flask, request, jsonify
from Controllers import transformText
import pickle
# from waitress import serve

app = Flask(__name__)

# importing models
model = pickle.load(open('./model/model.pkl', 'rb'))
tfidf = pickle.load(open('./model/vectorizer.pkl', 'rb'))

@app.route('/')
def hello_world():
    return 'Hello, World!'

@app.route('/predict', methods=['POST'])
def predict():
    # gets json data from the request body
    data = request.json
    transformed_sms = transformText.transform_text(data['sentence'])
    vector_input = tfidf.transform([transformed_sms])
    result = model.predict(vector_input)[0]
    return jsonify({'result': int(result)})

if __name__ == "__main__":
    app.run(debug=True)
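As written, the route assumes the body is a JSON object with a `sentence` key. If the key is missing, `data['sentence']` raises a `KeyError` and the client gets a 500 error instead of a helpful message. A small sketch of how the input could be checked first; `extract_sentence` is a hypothetical helper, not part of the original code:

```python
def extract_sentence(payload):
    """Return (sentence, error); exactly one of the two is None."""
    if not isinstance(payload, dict):
        return None, "Request body must be a JSON object."
    sentence = payload.get("sentence")
    if not isinstance(sentence, str) or not sentence.strip():
        return None, "Field 'sentence' must be a non-empty string."
    return sentence.strip(), None


# Inside predict(), this could be used roughly as follows:
#   sentence, error = extract_sentence(request.get_json(silent=True))
#   if error:
#       return jsonify({'error': error}), 400

print(extract_sentence({"sentence": "  Win a free prize now  "}))
print(extract_sentence({"text": "wrong key"}))
```

Returning a 400 status with an error message tells the client that the request was at fault, which is easier to debug than a generic server error.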
jsonify is a Flask function that serializes Python data (here, a dictionary) to JSON and wraps it in a response object with the application/json content type, ready to be sent to the client.
NOTE: As my model requires text preprocessing, I have defined a controller function, transform_text, in the transformText module inside the Controllers directory in the same root folder.
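The body of that function is not shown in this post, so here is a purely hypothetical sketch of what such a preprocessing step might look like: lowercase the text, keep only alphanumeric tokens, and drop a few stop words. The stop-word list and the exact steps are assumptions for illustration. For the import in main.py to resolve, the file would live at Controllers/transformText.py.

```python
import re

# Tiny illustrative stop-word list; a real project would likely use a fuller one.
STOP_WORDS = {"a", "an", "the", "is", "are", "to", "of", "and", "in"}


def transform_text(text):
    """Normalize raw text into a space-separated string of tokens."""
    text = text.lower()
    tokens = re.findall(r"[a-z0-9]+", text)  # drops punctuation and symbols
    kept = [token for token in tokens if token not in STOP_WORDS]
    return " ".join(kept)


print(transform_text("Congratulations! You WON a free ticket to the show."))
```

Whatever the function does, it needs to apply the same steps that were applied to the training data before the vectorizer was fitted; otherwise the vectorizer sees tokens it was never trained on.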
Make sure you generate a requirements.txt for your project using the following command:
pip freeze > requirements.txt
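The project directory now contains main.py and requirements.txt, the Controllers folder with the text-processing code, the folder with the two .pkl files, and the env virtual environment.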
Congrats, your REST API is ready 🚀🥳
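To try the endpoint from another Python script, something like the following could work. It is a sketch using only the standard library, assuming the server is running locally on Flask's default development address; `build_predict_request` and `predict` are hypothetical helper names:

```python
import json
import urllib.request

# Flask's development server listens on 127.0.0.1:5000 by default.
BASE_URL = "http://127.0.0.1:5000"


def build_predict_request(sentence, base_url=BASE_URL):
    """Build the POST request for /predict without sending it."""
    body = json.dumps({"sentence": sentence}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/predict",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(sentence, base_url=BASE_URL):
    """Send the sentence to the API and return the integer prediction."""
    request = build_predict_request(sentence, base_url)
    with urllib.request.urlopen(request, timeout=10) as response:
        return json.loads(response.read().decode("utf-8"))["result"]


# With the server running in another terminal:
#   print(predict("Free entry in a weekly competition!"))
```

The `Content-Type: application/json` header matters here, because the route reads the body through `request.json`.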
But wait, in your development environment you might encounter a CORS error when a front end served from a different origin calls the API.
To resolve it, add the following lines to your code:
from flask_cors import CORS
# Handling CORS
cors = CORS(app)
app.config['CORS_HEADERS'] = 'Content-Type'
And you are done 🤩