
Z. QIU

Implementation of Tensorflow Lite model on Android

In a recent interview, I was asked about my experience deploying trained TensorFlow models on the Android platform. I had tried an Android project cloned from GitHub that embedded a tflite model, but I had never deployed a model of my own in an Android application. So I did that exercise today, and I got my CNN model working on my Redmi Note 8 Pro.


CNN model

Here is the code for training a CNN model on the MNIST data set. The trained model is then converted to a tflite model, which the Android application uses to recognize handwritten digits.


import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"  # hide TensorFlow info and warning logs
import tensorflow as tf
from tensorflow.keras import layers, models



(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print(x_train[0,:,:])
## x_train.shape => (60000, 28, 28), y_train.shape => (60000,)
## x_test.shape => (10000, 28, 28), y_test.shape => (10000,)
x_train = tf.expand_dims(x_train, -1)
x_test = tf.expand_dims(x_test, -1)

## labels are already 1-D; squeeze just guards against a stray extra dimension
y_train = tf.squeeze(y_train)
y_test = tf.squeeze(y_test)

print("Dataset info: ", x_train.shape, y_train.shape, x_test.shape, y_test.shape)

batch_size = 128

train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(10000).batch(batch_size)

test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.batch(batch_size)

train_iter = iter(train_db)
sample = next(train_iter)
print(sample[0].shape, sample[1].shape)  

##  build a standard cnn model
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))

model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))

model.summary()


model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

train_history = model.fit(train_db, epochs=10, validation_data=test_db)

## once the model has been trained, convert it to tflite model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open('qiu_mnist_model.tflite', 'wb') as f:
    f.write(tflite_model)

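Before moving to Android, it can help to check the converted model on the desktop with the TFLite interpreter that ships with TensorFlow. The following is only a minimal sketch: `predict_digit` is a hypothetical helper name, and the small untrained stand-in model exists only so the snippet runs without training anything. Inspecting the input details this way also shows which shape and dtype the Android side has to feed.

```python
import numpy as np
import tensorflow as tf


def predict_digit(tflite_bytes, image):
    """Run one 28x28 grayscale image through a TFLite model.

    Returns the index of the highest-scoring class and the raw output vector.
    """
    interpreter = tf.lite.Interpreter(model_content=tflite_bytes)
    interpreter.allocate_tensors()
    input_info = interpreter.get_input_details()[0]
    output_info = interpreter.get_output_details()[0]

    # Shape the image as a batch of one, (1, 28, 28, 1), in the dtype the model expects.
    batch = np.asarray(image).reshape(1, 28, 28, 1).astype(input_info["dtype"])
    interpreter.set_tensor(input_info["index"], batch)
    interpreter.invoke()

    logits = interpreter.get_tensor(output_info["index"])[0]
    return int(np.argmax(logits)), logits


# Untrained stand-in with the same input and output shapes as the CNN above
# (an assumption for illustration only, not the trained model).
stand_in = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
])
tflite_bytes = tf.lite.TFLiteConverter.from_keras_model(stand_in).convert()

digit, logits = predict_digit(tflite_bytes, np.zeros((28, 28)))
print(digit, logits.shape)
```

With the real model, the bytes would come from reading `qiu_mnist_model.tflite` in binary mode, and the image would be one of the raw test images, so the prediction can be compared with the matching test label.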

Implementation in Android app

I referred to this post (see the references below) to obtain the original Android project, and I imported the Kotlin version into Android Studio. However, there were some bugs at first when I loaded my model into it.

My own model is placed in the assets folder:
[Screenshot: model file in the assets folder]

The most important thing for this work is the following Gradle setting:
[Screenshot: Gradle setting]

After about 15 minutes of debugging and code modifications, I got my model working.
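One detail worth keeping in mind when adapting the app code: the model above ends with `Dense(10)` without an activation and is trained with `from_logits=True`, so the converted model returns 10 raw logits rather than probabilities. The predicted digit is the index of the largest value, and a softmax is only needed if the app wants to display a confidence. Here is a minimal Python sketch of that post-processing; the helper names and the example logits are made up for illustration, and the same logic would be ported to Kotlin in the app.

```python
import math


def softmax(logits):
    """Convert raw logits to probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_prediction(logits):
    """Return (digit, probability) for the highest-scoring class."""
    probs = softmax(logits)
    digit = max(range(len(probs)), key=probs.__getitem__)
    return digit, probs[digit]


# Hypothetical logits for one image; the values are made up for illustration.
example_logits = [-2.1, 0.3, 1.2, 7.5, -0.8, 0.9, -3.0, 2.2, 0.1, -1.4]
digit, confidence = top_prediction(example_logits)
print(digit, round(confidence, 3))
```

Because softmax preserves the ordering of the logits, taking the argmax directly on the raw output gives the same digit.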

Check out the video (there is still an accuracy issue):

I will upload the Android project source code to my GitHub repo once I finish cleaning up the code and improving the performance.

References
  1. https://www.tensorflow.org/lite/performance/post_training_quantization
  2. https://margaretmz.medium.com/e2e-tfkeras-tflite-android-273acde6588
