Introduction
In this blog post I want to share a small application I developed that classifies images of handwritten digits, together with the lessons learned while developing it. When it comes to machine learning, I have mainly worked with text data in the past. Pattern recognition on image data is new to me, but I think it is a very useful skill.
The post is structured as follows. First, we introduce the concept of image classification and what makes it special compared to other problems such as text classification. The next section introduces a machine learning model called the Convolutional Neural Network (CNN), which is commonly used in image classification. The third section showcases an example application that performs handwritten digit classification through a web interface. We close the post by summarizing the main findings and ideas.
The application is written in Scala, HTML, CSS, and JavaScript. However, the concepts can be transferred to other languages as well. I am also trying to keep the mathematical details to a minimum, focusing on the information necessary for the reader to develop an intuition about the algorithms used. In case you are interested in getting a deeper understanding of the subject, I recommend taking a look at other tutorials, research papers, or books.
Image Classification
Machine learning algorithms expect data to be represented in some numerical format that the computer can understand. When using probabilistic models, for example, your data has to fit the format expected by the distributions your model is using.
As an example, consider a multinomial mixture model [1]. To utilize this type of model, you need to be able to convert your data into counts. For text this can be achieved by introducing a counting variable for each possible word of each cluster in each possible document. This model is very simple and works great for many use cases. However, it has one big disadvantage: It discards a lot of information, e.g. term co-occurrences and position within the document.
For image data this problem is even greater. While you can still determine whether an email is spam by just looking at the word counts, recognizing images with cats is much harder when only counting the number of pixels having a specific color. While text data is 1-dimensional, i.e. a sequence of terms, images are at least 2-dimensional, i.e. a matrix of pixels, and contain a lot more information in the spatial relation of the pixels.
Luckily there are other models we can use that take spatial information into account. A commonly used type of model is the Convolutional Neural Network (CNN). While research in this area has been ongoing for some time now [2], the era of GPU-based training has led to major breakthroughs in model performance in recent years [3].
How do we represent a raw image in the computer? The smallest addressable element of a computer image is a pixel. Each pixel has a position and a color. We can represent the color in different forms. A commonly used scheme for colored images is red-green-blue (RGB). If we reserve 24 bits for each pixel, i.e. 8 bits for each of the three colors, we can encode 256 different shades of red, green, and blue, respectively. Combining them allows us to represent around 16 million different colors.
In order to access the image information from within our code, we can store the pixels in a two-dimensional array, i.e. a matrix. While it would be possible to combine all three color channels inside a single coordinate of this matrix, it is more efficient to store only a single number. This leaves us with a matrix for each channel, so that we can represent grey-scale images as matrices and colored images as 3-dimensional tensors. The following figure illustrates how this process would look for a 3×3 pixel image. Note that in real images the colors will be mixed most of the time.
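To make this a bit more concrete, here is a small sketch (in Scala, not taken from the application described later) of how the three channel matrices can be extracted from a java.awt.image.BufferedImage, assuming 8 bits per channel:

import java.awt.image.BufferedImage

// Split a BufferedImage into one matrix per color channel (values 0-255).
def toChannels(img: BufferedImage): (Array[Array[Int]], Array[Array[Int]], Array[Array[Int]]) = {
  val (h, w) = (img.getHeight, img.getWidth)
  val red, green, blue = Array.ofDim[Int](h, w)
  for (y <- 0 until h; x <- 0 until w) {
    val rgb = img.getRGB(x, y)        // packed integer containing all channels
    red(y)(x) = (rgb >> 16) & 0xFF    // 8 bit red component
    green(y)(x) = (rgb >> 8) & 0xFF   // 8 bit green component
    blue(y)(x) = rgb & 0xFF           // 8 bit blue component
  }
  (red, green, blue)
}

For a grey-scale image the three matrices are identical, so a single one is enough.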
Now let's take a look at how CNNs work and how we can use this image representation as input for a CNN-based classifier.
Convolutional Neural Networks
Architecture
A neural network is a machine learning model which consists of connected layers of neurons. A neuron contains a number, the so-called activation. Connections are assigned weights, which describe the strength of the signal to the connected neuron.
Input data is fed into the first layer, activating each input neuron to some extent. Based on the weights and an activation function, the network determines which neurons from the next layer to activate and how strong the activation is going to be. This so-called feedforward process is continued until the output neurons are activated. The architecture of a neural network has a huge influence on which data it can work with and on its performance. The following figure illustrates a simple neural network with three layers.
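To illustrate the idea (this is a toy sketch, not how the library used later in the post implements it), computing the activations of one layer from the previous layer boils down to a weighted sum plus a bias, squashed by an activation function such as the sigmoid:

// Activations of one layer, given the previous layer's activations.
def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

def layerActivations(weights: Array[Array[Double]],  // one row of weights per neuron
                     biases: Array[Double],
                     previous: Array[Double]): Array[Double] =
  weights.zip(biases).map { case (row, bias) =>
    val weightedSum = row.zip(previous).map { case (w, a) => w * a }.sum
    sigmoid(weightedSum + bias)                       // squash the result into (0, 1)
  }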
CNNs are a special type of neural network. They can be divided into two parts: a feature learning part and a classification part. Each part consists of one or multiple layers. Feature learning is typically done by combining two types of layers: convolution layers and pooling layers. Classification is then performed based on the learned features through dense layers, also known as fully connected layers. Additionally, there is an input layer, containing the image data, as well as an output layer, containing the different classes we are trying to predict.
The following figure illustrates a CNN with one convolution layer, one pooling layer, and one dense layer. The task is to predict whether the image depicts a cat. Layers that lie between the input and output layer are also referred to as hidden layers, as their state is not directly visible when treating the model as a black box.
Considering a single color channel, the input layer can either be the raw image matrix or a preprocessed one, e.g. cropped, resized, with color values scaled between 0 and 1, and so on. The output layer represents the weights of each possible class that are assigned by the last hidden layer. In the next subsection we want to take a closer look at the different hidden layer types.
Convolution Layers
A convolution layer is responsible for convolving a filter with the previous layer. If you are not familiar with 2-dimensional image filtering, you can take a look at the Image Filtering post from Machine Learning Guru. A filter can be viewed as a smaller image, i.e. a smaller matrix than the input, which is applied to a part of the input. If the part of the image matches what the filter expects, the output value will be high. Convolving the filter with the full input will yield another image that highlights certain aspects of the input.
Let's look at an example. The following figure shows the application of the Sobel-Feldman operator [4], also known as the Sobel edge detector filter, to our blue cat. To be precise, we are applying two filters, one for horizontal and one for vertical edges. We then combine both results to obtain an image showing both horizontal and vertical edges. The filter kernels are depicted in the center of the figure.
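For reference, these are the two standard 3×3 Sobel kernels; the two filter responses are typically combined via the gradient magnitude sqrt(gx² + gy²):

val sobelHorizontal = Array(       // responds strongly to horizontal edges
  Array(-1.0, -2.0, -1.0),
  Array( 0.0,  0.0,  0.0),
  Array( 1.0,  2.0,  1.0))

val sobelVertical = Array(         // responds strongly to vertical edges
  Array(-1.0, 0.0, 1.0),
  Array(-2.0, 0.0, 2.0),
  Array(-1.0, 0.0, 1.0))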
There are different configuration options when defining a convolution layer. Each convolution layer can have one or multiple filters. The convolution layer will then output an intermediate representation of the input for each filter. The more filters, the more diverse our image features can become.
In addition to the number of filter kernels, we can select a kernel size. The kernel size determines the locality of the filter, i.e. how many of the surrounding pixels are being taken into account when applying the filter. Secondly, we need to pick a stride value. The stride determines how many pixels we advance when convolving. A stride of 1 will move the filter across every pixel, while a stride of 2 will skip every second pixel.
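Putting kernel size and stride together, a naive convolution of a single-channel image (as a matrix) with a kernel could look like the following sketch. Real implementations are heavily optimized, and padding is ignored here, so the output shrinks:

// Convolve a single-channel image with a square kernel, no padding.
def convolve(image: Array[Array[Double]],
             kernel: Array[Array[Double]],
             stride: Int): Array[Array[Double]] = {
  val k = kernel.length                          // kernel size, e.g. 3 for a 3x3 kernel
  val outRows = (image.length - k) / stride + 1
  val outCols = (image(0).length - k) / stride + 1
  Array.tabulate(outRows, outCols) { (r, c) =>
    var sum = 0.0
    for (i <- 0 until k; j <- 0 until k)
      sum += image(r * stride + i)(c * stride + j) * kernel(i)(j)
    sum                                          // high if this image patch matches the filter
  }
}

With a stride of 1, a 3×3 kernel, and no padding, a 28×28 input therefore yields a 26×26 output.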
The question is how do we pick the filters we want to use? The answer is, we don't. The great thing about neural networks is that they learn the features themselves based on the training data. The training procedure will be discussed a bit more in a later section. Now let's move to the second type of feature learning layers: Pooling layers.
Pooling Layers
Pooling layers are applied to down-sample the input. The goal is to reduce the computational complexity of the model and to avoid overfitting. The information loss is usually not that problematic as the exact location of the features is less important than the relation between them.
Pooling is implemented by applying a special filter function while choosing the kernel size and stride value in a way that the filter applications do not overlap. A commonly used technique is called max pooling. In max pooling we select the maximum value of the sub-region for our sub-sampled output. In the next figure we can see the result of applying 2×2 max-pooling to a 4×4 input matrix.
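As a sketch, 2×2 max pooling over a matrix can be written like this (again illustrative, not the library implementation used later):

// Down-sample a matrix by taking the maximum of each non-overlapping size x size block.
def maxPool(input: Array[Array[Double]], size: Int = 2): Array[Array[Double]] =
  Array.tabulate(input.length / size, input(0).length / size) { (r, c) =>
    (for (i <- 0 until size; j <- 0 until size)
      yield input(r * size + i)(c * size + j)).max
  }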
The following figure depicts the result of sub-sampling the output of the convolution layer twice. Note that sub-sampling reduces the image size, but I scaled the size up again to visualize the loss of information.
How can we use the derived features to predict a class? Let's find out by looking closer into how dense layers work.
Dense Layers
Dense layers connect every neuron from the previous layer to the next one. In the context of CNNs they form the classification part of the network. Neurons in the dense layers learn which features each class is composed of.
Dense layers are more complex in terms of parameter fitting than convolution layers. A filter with a 3×3 kernel from a convolution layer has 9 parameters independent of the number of input neurons. A fully connected layer of 16 neurons with 28×28 neurons on the previous layer already has 28×28×16 = 12,544 weights.
Now that we are more familiar with the different components of CNNs, you might wonder how to find the correct values for all parameters, i.e. the filter kernels and weights in the dense layers.
Training
As with other supervised machine learning models, training is done based on example inputs for which the class label is known. An untrained CNN is initialized with random parameters. We can then feed training examples through the network and inspect the activation of the output neurons. Based on the expected activation, i.e. full activation of the neuron associated with the correct class and no activation of the rest, we can derive a cost function which captures how wrong the network was.
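As a toy illustration (the actual loss function used by the model later in this post may differ), a simple squared-error cost compares the output activations with a one-hot target, i.e. 1 for the correct class and 0 for all others:

// Squared-error cost for a single training example.
def cost(output: Array[Double], trueClass: Int): Double =
  output.zipWithIndex.map { case (activation, i) =>
    val target = if (i == trueClass) 1.0 else 0.0
    math.pow(activation - target, 2)             // penalize deviation from the expected activation
  }.sum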
Then we can start to tune the parameters to reduce the cost. This is done starting from the output neurons, adjusting the parameters of each layer back towards the input layer. This learning process is referred to as backpropagation. But how do we know which parameters to increase, which to decrease, and by how much?
I'm not going to go into too much mathematical detail here, but you might remember from calculus that for some functions you can compute a derivative, telling you how the output of the function changes given a change in the input variable. The derivative represents the slope of the tangent of the function when plotted. If we computed this for our cost function, it would tell us how each parameter influences the outcome towards our expected class label.
As our cost function has not only one but potentially thousands of input variables (recall the number of weights even for a small dense layer), we utilize the so-called gradient. The gradient is a generalization of the derivative for multi-variable functions. To be precise, we want to use the negative gradient, as we aim at reducing the cost. The negative gradient tells us how we need to adjust the network parameters to better classify the training examples. This method is called gradient descent.
Computing the exact negative gradient for all our training examples is computationally infeasible most of the time. However, we can use a small trick: The input data is shuffled and grouped into small batches. We then compute the gradient only on this small subset, adjust the parameters of the network accordingly, and continue with the next batch. This so-called stochastic gradient descent gives a good-enough approximation of the exact answer.
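Schematically, one pass of stochastic gradient descent looks like the following sketch. The function gradientOn, which estimates the gradient on a mini-batch, as well as the other names are illustrative placeholders, not part of the actual application:

import scala.util.Random

// One pass (epoch) of stochastic gradient descent over the training examples.
def sgdEpoch[A](examples: Seq[A],
                parameters: Array[Double],
                gradientOn: (Seq[A], Array[Double]) => Array[Double],
                learningRate: Double,
                batchSize: Int): Array[Double] =
  Random.shuffle(examples)                  // shuffle, then split into mini-batches
    .grouped(batchSize)
    .foldLeft(parameters) { (params, batch) =>
      val gradient = gradientOn(batch, params)
      params.zip(gradient).map { case (p, g) => p - learningRate * g }  // step against the gradient
    }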
Keep in mind, however, that by descending the gradient we can only improve as much as the initial random parameters allow us to. The network might not be able to improve without starting from completely different weights, getting stuck in a so-called local minimum of the cost function. Several techniques exist to avoid getting stuck in a local minimum, but they also have their disadvantages.
Now that we have our trained model, we can feed it images without a label and look at the output to determine the correct class. Next, let's look at the "Hello World" example of image classification and the small app I built based on it.
Handwritten Digit Recognition
The Data
The "Hello World" of image classification is a seemingly simple, yet non-trivial problem of classifying handwritten digits. There is a rich training and test dataset is available online for free within the Modified National Institute of Standards and Technology database, widely known as MNIST database.
Each digit is available as a 28×28 pixel grey scale image. The following picture shows a few example images for each digit.
Application Architecture
In order to build something that one can use and play around with, my goal was to build a web application that allows you to draw a digit and get it classified. I am using Deeplearning4j (DL4J) to build, train, validate, and apply the model. It is an open source deep learning library for the JVM. Please find a small architecture diagram below.
The application is split into two parts:
- Training & Validation
- Prediction
The training and validation happens offline. It reads the data from a directory structure which already splits the data into training and test sets and contains the individual digits in their respective directories. After training is successful, the network gets serialized and persisted on the filesystem (model.zip). The prediction API then loads the model on startup and uses it to serve incoming requests from the front end.
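To give an idea of how such a directory layout can be read, here is a common DataVec-based sketch, where the parent directory name of each image is used as its label. The path and batch size are placeholders, and the actual application may read the data differently:

import java.io.File
import org.datavec.api.io.labels.ParentPathLabelGenerator
import org.datavec.api.split.FileSplit
import org.datavec.image.recordreader.ImageRecordReader
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator

val labelGenerator = new ParentPathLabelGenerator()      // label = name of the parent folder (0-9)
val recordReader = new ImageRecordReader(28, 28, 1, labelGenerator)
recordReader.initialize(new FileSplit(new File("data/training")))
val trainingData = new RecordReaderDataSetIterator(recordReader, 64, 1, 10)  // batch size 64, 10 classes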
Before we look at the individual components in a bit more detail, please note that the source code is available on GitHub and that the app is online and can be tried out thanks to Heroku. I am only using a free tier, so you might have to wait a bit when the application is used for the first time after a while, as it lazily starts the server.
The Front End
The front end is a simple HTML 5 canvas plus a bit of JavaScript to send the data to the back end. It is heavily inspired by the Create a Drawing App with HTML 5 Canvas and JavaScript tutorial by William Malone. In case you cannot access the live version right now, you can check out a screen shot of the front end below.
It features a drawing canvas, a button to send the canvas content to the back end, a button to clear the canvas, and an output area for the classification result. The index.html is not very complicated. Here are the HTML elements used:
<body>
  <div id="canvasDiv"></div>
  <div id="controls">
    <button id="predictButton" type="button">Predict</button>
    <button id="clearCanvasButton" type="button">Clear</button>
  </div>
  <div id="predictionResult"></div>
</body>
We then add some CSS (app.css) to the mix to make it look less ugly. The JavaScript code (app.js) is basic jQuery, nothing fancy and very prototypical. It first builds up the canvas and defines the drawing functions. Prediction is done by sending the canvas content to the back end. Once the result arrives, we show it in the output div.
$('#predictButton').mousedown(function(e) {
  canvas.toBlob(function(d) {
    var fd = new FormData();
    fd.append('image', d);
    $.ajax({
      type: "POST",
      url: "predict",
      data: fd,
      contentType: false,
      processData: false
    }).done(function(o) {
      $('#predictionResult').text(o);
    });
  });
});
The Back End
The back end (PredictAPI.scala) is a small Akka HTTP web server. On startup we load the model from disk. We have to wrap the access in a synchronized block, as the default model implementation of DL4J is not thread safe.
val model = new SynchronizedClassifier(
  ModelSerializer.restoreMultiLayerNetwork("model.zip")
)
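The SynchronizedClassifier is a thin wrapper; a minimal sketch of its shape could look like this (the actual class in the repository may differ in its method signatures):

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.nd4j.linalg.api.ndarray.INDArray

class SynchronizedClassifier(model: MultiLayerNetwork) {
  def predict(image: INDArray): Int = this.synchronized {  // guard the non-thread-safe network
    model.predict(image).head                              // most likely class of the (single) example
  }
}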
There is a route for the static files, i.e. index.html, app.js, and app.css, as well as one for receiving images of digits for prediction.
val route =
  path("") {
    getFromResource("static/index.html")
  } ~
  pathPrefix("static") {
    getFromResourceDirectory("static")
  } ~
  path("predict") {
    fileUpload("image") {
      case (fileInfo, fileStream) =>
        val in = fileStream.runWith(StreamConverters.asInputStream(3.seconds))
        val img = invert(MnistLoader.fromStream(in))
        complete(model.predict(img).toString)
    }
  }
For every incoming image we have to apply some basic transformations like resizing and scaling, which are implemented in the MnistLoader.fromStream method. We are also inverting the image, as the network is trained to classify white digits on a black background.
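As an illustration of the kind of preprocessing involved (not the exact MnistLoader implementation), the incoming image can be resized to 28×28, converted to grey scale, scaled to values between 0 and 1, and inverted:

import java.awt.Image
import java.awt.image.BufferedImage

// Turn an arbitrary image into a 28x28 array of values in [0, 1], white digit on black.
def toModelInput(original: BufferedImage): Array[Double] = {
  val resized = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY)
  val graphics = resized.createGraphics()
  graphics.drawImage(original.getScaledInstance(28, 28, Image.SCALE_SMOOTH), 0, 0, null)
  graphics.dispose()
  (0 until 28).toArray.flatMap { y =>
    (0 until 28).map { x =>
      val grey = resized.getRaster.getSample(x, y, 0) / 255.0  // scale to [0, 1]
      1.0 - grey                                               // invert the colors
    }
  }
}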
The Model
The model used is a seven layer CNN, heavily inspired by the DL4J Code Example for CNNs. The hidden layers are two pairs of convolution-pooling layers, as well as one dense layer. It is trained using stochastic gradient descent with batches of 64 images. The test accuracy of the model is 98%.
The training and validation process is implemented in TrainMain.scala. There you can also find the exact model configuration. I don't want to go into too much detail at this point, but if you have any questions regarding the model architecture, feel free to drop a comment.
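To still give a rough idea of its shape, here is a condensed sketch of a DL4J configuration for such a network, modeled on the DL4J LeNet example. The hyperparameters shown are illustrative and not necessarily the ones used in TrainMain.scala:

import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.inputs.InputType
import org.deeplearning4j.nn.conf.layers.{ConvolutionLayer, DenseLayer, OutputLayer, SubsamplingLayer}
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.learning.config.Nesterovs
import org.nd4j.linalg.lossfunctions.LossFunctions

val conf = new NeuralNetConfiguration.Builder()
  .seed(42)
  .updater(new Nesterovs(0.01, 0.9))
  .list()
  .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(1).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())
  .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())
  .layer(2, new ConvolutionLayer.Builder(5, 5).stride(1, 1).nOut(50).activation(Activation.IDENTITY).build())
  .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2).build())
  .layer(4, new DenseLayer.Builder().nOut(500).activation(Activation.RELU).build())
  .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10).activation(Activation.SOFTMAX).build())
  .setInputType(InputType.convolutionalFlat(28, 28, 1))
  .build()

val network = new MultiLayerNetwork(conf)
network.init()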
Deployment with Heroku
I chose to deploy the application with Heroku as it allows you to quickly deploy applications publicly, has a free tier, and integrates very well with the development workflow. I am using the Heroku CLI.
For Scala projects built with SBT, Heroku will execute sbt stage. This will produce a binary artifact of the app together with all library dependencies. The Procfile specifies how to start the app. Here are the commands required to deploy to Heroku:
- heroku login (logging in to your Heroku account)
- heroku create (initializing the heroku remote)
- git push heroku master (push changes, triggering a build)
- heroku open (open the application URL in your browser)
Problems
If you tried the application you might have run into some weird output. In fact, there are multiple issues which might lead to misclassification of your drawn digit even though the model has 98% accuracy.
One factor is that the images are not centered. Although the combination of convolution layers and subsampling through pooling helps, I suspect that moving and resizing all digits to the center of the canvas would aid the performance. For optimal results, try drawing the image in the lower 2/3 of the canvas.
Additionally, the training data captures a certain style of handwriting that is common in the US. While in other parts of the world the digit 1 is written with multiple strokes, in the US people often write it as a single line. This can lead to a 1, written differently, being classified as a 7. The following figure illustrates this.
Summary
In this post we have seen how CNNs can be used to classify image data. Using a combination of approximate optimization techniques, sub-sampling and filter application we are able to train a deep network that captures features of the input images well.
Using a bit of JavaScript, HTML and CSS you are able to develop a front end for drawing images to be classified. The back end can be implemented using an HTTP server like Akka HTTP in combination with a deep learning framework like DL4J.
We have also seen that the classification performance in the real world only matches the test accuracy if the real data corresponds to the training and test data used when building the model. It is crucial to monitor model performance at runtime, adjusting or retraining the model periodically to keep the accuracy high.
References
- [1] Rigouste, L., Cappé, O. and Yvon, F., 2007. Inference and evaluation of the multinomial mixture model for text clustering. Information processing & management, 43(5), pp.1260-1280.
- [2] LeCun, Y., Bottou, L., Bengio, Y. and Haffner, P., 1998. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11), pp.2278-2324.
- [3] Ciregan, D., Meier, U. and Schmidhuber, J., 2012, June. Multi-column deep neural networks for image classification. In Computer vision and pattern recognition (CVPR), 2012 IEEE conference on (pp. 3642-3649). IEEE.
- [4] Sobel, I., Feldman, G., A 3x3 Isotropic Gradient Operator for Image Processing, presented at the Stanford Artificial Intelligence Project (SAIL) in 1968.
If you liked this post, you can support me on ko-fi.
Top comments (10)
Very well written! Loved the explanations for the pooling and batch normalization. Does dl4j take advantage of GPUs?

I did a project on recognizing numbers on houses from the Google Street View house numbers dataset. I used keras for training as it has ready-made network architectures for the famous papers like resnet, vgg16 and xception. These are trained on datasets like ImageNet or CIFAR. We have to replace the final layer (for multi-class classification) with our own layer. Keras also has an image data generator where you can slightly rotate, shear, scale and blur images to avoid overfitting and increase the robustness of the model.

One issue I faced was false positives, e.g. recognizing a door handle as a '1', as the dataset itself just has labels for the digits. I had to create a separate 'negative' class and extract random patches from the training images. With that, the accuracy went up to 97%. One more thing I found was that pre-processing the images makes a big difference. In my project, mean subtraction, normalization and a light Gaussian blur reduced the training time and increased the accuracy.
Hi Raunak,
glad you enjoyed the post :) I tried to make it not too theoretical, but without some intuition about the math I find it hard to understand how to apply the algorithms.
DL4J uses ND4J under the hood for numerical computations on the tensors. ND4J supports native libraries for many different platforms. If you want to use NVIDIA GPUs, you can simply use the nd4j-cuda-* dependency. I haven't tried it out yet, though.
I haven't used keras yet, but I'm planning to check it out later. I also want to try a more sophisticated problem.

I completely agree with your last point about things that the network has never seen before. With CNNs it's very important to pick the right training data and to have well-labeled data. In my example I was only rescaling the colors to (0, 1) but didn't do any other preprocessing steps. Do you know if, similar to the convolution effect, there are networks that can learn some parts of the preprocessing as well? That would be interesting.
Thanks for your feedback!
Not sure about networks learning the pre-processing, as most of the image pipelines I have seen try a lot of hit-and-miss steps with regard to pre-processing. I have seen people try thresholding images, using gradient or edge images, using RGB vs. grayscale vs. HSL. I think there is a lot of variability in pre-processing, which makes it difficult for a network to learn. This is one case where having knowledge of your specific dataset and some knowledge of computer vision helps; otherwise we would require a very large number of training images. If we have a small number of images, we can 'augment' the dataset by using image data generators which slightly change images by rotating/resizing/blurring/distorting, etc.
There are also LSTM networks which have a concept of memory but they are used more for speech recognition and time series. I haven't worked with these yet.
Got LSTMs on my to-do list already. Definitely going to check them out!
Thanks Raunak!
Raunak, it works very well with GPUs. We use them all the time.
DL4J comes with an image pre-processing transform API that you can see in action here:
github.com/deeplearning4j/dl4j-exa...
DL4J also comes with many of the famous networks, many of which actually come from Keras via the model import feature. Just add deeplearning4j-zoo to your project and use the TransferLearning class to edit the graph like so:
github.com/deeplearning4j/dl4j-exa...
Thanks Eduardo. The model import from keras looks very useful.

One thing you want to be careful with is that DL4J models aren't thread safe. You want to wrap them inside the ParallelInference wrapper, which has a few knobs for maximizing performance. You can see how to use it in the unit tests:
github.com/deeplearning4j/deeplear...
Thanks for letting me know. I was aware that the models aren't thread safe, which is why I wrapped the access in a synchronized block (as mentioned in the post). Definitely going to take a look at the ParallelInference wrapper. Thanks for the link!

Could you please explain how our image stream can be converted to the image size and format that the model expects?
I can certainly try. Would you be able to provide me with more details on your image stream? Are you also using DL4J?