DEV Community

Lars Wächter



Recognizing hand drawn Doodles using Deep Learning

This post was originally published on my blog.

In November 2016 Google released an online game called "Quick, Draw!" that asks the player to draw an image of a prescribed object and then uses a neural network to guess what the drawing represents.
In total, the neural network can recognize 345 different objects.

Luckily, Google released the dataset they trained their neural network with, which includes more than 50 million images hand-drawn by players, so you can use it to train your own neural network. That's exactly what this article is about: we'll build a convolutional neural network that recognizes hand-drawn images using the Quick, Draw! dataset. Furthermore, we'll build a simple web app that lets the user draw images and have them predicted by the network model.

The complete code and the trained model are available on GitHub. You can find the web app demo on Heroku.

Neural Network

What we'll do:

  1. Generate, load & visualize the training data
  2. Design the network
  3. Train & export the model
  4. Convert the model to TFLite


For developing the convolutional neural network, we'll use the dependencies listed in requirements.txt.

Tip: create a new virtual environment for that.

Install the dependencies using the following command.

pip3 install -r requirements.txt

Let's have a look at the dataset before writing the actual code.


You can find the complete dataset on Google Cloud Platform. It contains more than 50 million images in 345 different categories. A list of all included categories is available here.

A single image is represented as follows in the Quick, Draw! dataset:

  "timestamp":"2017-03-01 20:41:36.70725 UTC",

The following properties are important for us:

  • word (the image's category)
  • recognized (whether the drawing was recognized by Google's AI)
  • drawing (an array representing the vector drawing)
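The raw Quick, Draw! files are newline-delimited JSON (.ndjson), one drawing per line. As a minimal sketch of how the three properties above can be read with the standard library, here is a parser plus a filter that keeps only recognized drawings. The sample line is made up for illustration and contains only these three fields; real records carry more (such as timestamp), and the function names are hypothetical.

```python
import json

# Hypothetical sample record (made-up values, reduced to the fields we need)
SAMPLE_LINE = (
    '{"word": "nose", "recognized": true, '
    '"drawing": [[[10, 20, 30], [5, 15, 5], [0, 40, 80]]]}'
)


def parse_record(line):
    """Return (word, recognized, drawing) for one ndjson line."""
    record = json.loads(line)
    return record["word"], record["recognized"], record["drawing"]


def recognized_drawings(lines):
    """Yield (word, drawing) pairs, skipping drawings Google's AI did not recognize."""
    for line in lines:
        word, recognized, drawing = parse_record(line)
        if recognized:
            yield word, drawing


word, recognized, drawing = parse_record(SAMPLE_LINE)
print(word, recognized, len(drawing))  # nose True 1
```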

The actual image in "drawing" is a multi-dimensional array that contains, for each stroke, the x and y pixel coordinates of its points and the time t at which each point was drawn:

[
  [  // First stroke
    [x0, x1, x2, x3, ...],
    [y0, y1, y2, y3, ...],
    [t0, t1, t2, t3, ...]
  ],
  [  // Second stroke
    [x0, x1, x2, x3, ...],
    [y0, y1, y2, y3, ...],
    [t0, t1, t2, t3, ...]
  ],
  ... // Additional strokes
]
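Because each stroke keeps its x and y coordinates in separate, parallel arrays, turning a stroke into the list of points a line-drawing routine expects is just a matter of zipping them; the timestamps aren't needed for that. A minimal sketch (the function name is hypothetical):

```python
def strokes_to_polylines(drawing):
    """Convert [[xs, ys, ts], ...] (or [[xs, ys], ...]) into lists of (x, y) tuples."""
    polylines = []
    for stroke in drawing:
        xs, ys = stroke[0], stroke[1]  # a third entry (timestamps), if present, is ignored
        polylines.append(list(zip(xs, ys)))
    return polylines


drawing = [
    [[0, 10, 20], [0, 5, 0], [0, 16, 33]],  # first stroke
    [[5, 5], [20, 30], [50, 66]],  # second stroke
]
print(strokes_to_polylines(drawing))
# [[(0, 0), (10, 5), (20, 0)], [(5, 20), (5, 30)]]
```

The same pairing of stroke[0][i] with stroke[1][i] shows up again in the backend, where the strokes are drawn with Pillow.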


In order to train the neural network, we create our own, slightly modified dataset from Google's. To download and access the data on Google Cloud Platform, we use a Python package called quickdraw.

The following steps are required to create our own dataset:

  1. Load 1200 training images for each class from the cloud storage
  2. Resize them to 28x28 pixels
  3. Save them as PNG
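
from pathlib import Path

from quickdraw import QuickDrawData, QuickDrawDataGroup

image_size = (28, 28)

def generate_class_images(name, max_drawings, recognized):
    directory = Path("dataset/" + name)

    if not directory.exists():
        directory.mkdir(parents=True)

    # Download (or load from the local cache) up to max_drawings drawings of this class
    images = QuickDrawDataGroup(name, max_drawings=max_drawings, recognized=recognized)
    for img in images.drawings:
        filename = directory.as_posix() + "/" + str(img.key_id) + ".png"
        # img.image renders the strokes as a PIL image; shrink it and save it as PNG
        img.image.resize(image_size).save(filename)

for label in QuickDrawData().drawing_names:
    generate_class_images(label, max_drawings=1200, recognized=True)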

Setting recognized=True ensures that only images that have been recognized by Google's AI are loaded.

After the generation has finished, there should be a directory structure like the following, in which each class has its own subdirectory containing 1200 images:

└── dataset
|   ├── aircraft carrier
|   │   ├── 4504134474530816.png
|   │   ├── 4506833509154816.png
|   │   ├── ...
|   ├── airplane
|   │   ├── 4508382553702400.png
|   │   ├── 4508818253807616.png
|   │   ├── ...
|   ├── ...

In total, there should be 414,000 images (345 * 1200).
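A quick way to verify the generated dataset is to count the PNG files in each class directory. This is a small standard-library sketch; the function names are made up, and the defaults assume the dataset/ layout and the 345 × 1200 numbers from above.

```python
from pathlib import Path


def count_images(root="dataset"):
    """Return {class_name: number_of_png_files} for every subdirectory of root."""
    return {
        class_dir.name: sum(1 for _ in class_dir.glob("*.png"))
        for class_dir in sorted(Path(root).iterdir())
        if class_dir.is_dir()
    }


def check_dataset(counts, n_classes=345, per_class=1200):
    """Print class/image totals and return the classes with an unexpected image count."""
    total = sum(counts.values())
    print(f"{len(counts)} classes, {total} images (expected {n_classes * per_class})")
    return {name: n for name, n in counts.items() if n != per_class}
```

Usage: `incomplete = check_dataset(count_images("dataset"))` returns an empty dict when every class directory holds exactly 1200 images.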


Now we can load the images using Keras' image_dataset_from_directory function and split them into a training and a validation set. The batch size is set to 32.

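from tensorflow.keras.utils import image_dataset_from_directory

batch_size = 32

# Same seed + validation_split in both calls keeps the subsets disjoint (the seed value is arbitrary)
train_ds = image_dataset_from_directory("dataset", validation_split=0.2, subset="training", seed=123,
                                        color_mode="grayscale", image_size=(28, 28), batch_size=batch_size)

val_ds = image_dataset_from_directory("dataset", validation_split=0.2, subset="validation", seed=123,
                                      color_mode="grayscale", image_size=(28, 28), batch_size=batch_size)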

With an 80/20 split, we end up with 331,200 training and 82,800 validation images.
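As a sanity check on these numbers: as far as I can tell from the Keras implementation, int(validation_split * n) files go to the validation subset and the rest to training, and with a batch size of 32 that fixes the number of steps per epoch. This sketch only mirrors the arithmetic (the function name is made up):

```python
import math


def split_sizes(n_images, validation_split=0.2, batch_size=32):
    """Mirror the train/validation split arithmetic and count batches per epoch."""
    n_val = int(validation_split * n_images)
    n_train = n_images - n_val
    steps_per_epoch = math.ceil(n_train / batch_size)  # the last batch may be smaller
    return n_train, n_val, steps_per_epoch


print(split_sizes(345 * 1200))  # (331200, 82800, 10350)
```

So each of the 14 training epochs later on runs 10,350 batches of 32 images.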


Next, let's visualize some random training images using matplotlib:

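import matplotlib.pyplot as plt

plt.figure(figsize=(8, 8))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        # Drop the channel axis: imshow expects (height, width) for grayscale data
        data = images[i].numpy().astype("uint8")[:, :, 0]
        plt.imshow(data, cmap='gray', vmin=0, vmax=255)
        plt.title(train_ds.class_names[int(labels[i])])
        plt.axis("off")
plt.show()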

This outputs a 3x3 grid of nine random training images.



In the next step, we design the convolutional neural network. For this, we use the following seven Keras layer types: Rescaling, BatchNormalization, Conv2D, MaxPooling2D, Flatten, Dense and Dropout.

There are 345 classes in total. The input shape is (28, 28, 1), since all images are 28x28 pixels with 1 color channel (grayscale).

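from tensorflow.keras import Sequential
from tensorflow.keras.layers import (BatchNormalization, Conv2D, Dense, Dropout,
                                     Flatten, MaxPooling2D, Rescaling)

n_classes = 345
input_shape = (28, 28, 1)

# Layer order follows the model summary below
model = Sequential([
    Rescaling(1. / 255, input_shape=input_shape),
    BatchNormalization(),

    Conv2D(6, kernel_size=(3, 3), padding="same", activation="relu"),
    Conv2D(8, kernel_size=(3, 3), padding="same", activation="relu"),
    Conv2D(10, kernel_size=(3, 3), padding="same", activation="relu"),
    BatchNormalization(),
    MaxPooling2D(pool_size=(2, 2)),

    Flatten(),

    # The dropout rates aren't shown in the original; 0.2 is an assumption
    Dense(700, activation="relu"),
    BatchNormalization(),
    Dropout(0.2),

    Dense(500, activation="relu"),
    BatchNormalization(),
    Dropout(0.2),

    Dense(400, activation="relu"),
    Dropout(0.2),

    Dense(n_classes, activation="softmax")
])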

Here's a summary of the model. In total, the model has 2,068,019 parameters, 2,065,597 of which are trainable.

Layer (type)                 Output Shape              Param #
rescaling (Rescaling)        (None, 28, 28, 1)         0
batch_normalization (BatchNo (None, 28, 28, 1)         4
conv2d (Conv2D)              (None, 28, 28, 6)         60
conv2d_1 (Conv2D)            (None, 28, 28, 8)         440
conv2d_2 (Conv2D)            (None, 28, 28, 10)        730
batch_normalization_1 (Batch (None, 28, 28, 10)        40
max_pooling2d (MaxPooling2D) (None, 14, 14, 10)        0
flatten (Flatten)            (None, 1960)              0
dense (Dense)                (None, 700)               1372700
batch_normalization_2 (Batch (None, 700)               2800
dropout (Dropout)            (None, 700)               0
dense_1 (Dense)              (None, 500)               350500
batch_normalization_3 (Batch (None, 500)               2000
dropout_1 (Dropout)          (None, 500)               0
dense_2 (Dense)              (None, 400)               200400
dropout_2 (Dropout)          (None, 400)               0
dense_3 (Dense)              (None, 345)               138345
Total params: 2,068,019
Trainable params: 2,065,597
Non-trainable params: 2,422
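The numbers in this summary follow from the standard parameter formulas, so they can be recomputed by hand. The sketch below does only that arithmetic (it doesn't import Keras; the helper names are made up). It also shows where the 2,422 non-trainable parameters come from: the moving mean and variance of the four batch normalization layers.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # one kernel per (input channel, filter) pair plus one bias per filter
    return (kernel_h * kernel_w * in_channels + 1) * filters


def dense_params(inputs, units):
    # weight matrix plus one bias per unit
    return (inputs + 1) * units


def batchnorm_params(channels):
    # gamma and beta are trainable; the moving mean and variance are not
    return {"trainable": 2 * channels, "non_trainable": 2 * channels}


convs = [conv2d_params(3, 3, 1, 6), conv2d_params(3, 3, 6, 8), conv2d_params(3, 3, 8, 10)]
flat = 14 * 14 * 10  # MaxPooling2D output (14, 14, 10), flattened to 1960 values
denses = [dense_params(flat, 700), dense_params(700, 500),
          dense_params(500, 400), dense_params(400, 345)]
bn_channels = [1, 10, 700, 500]  # channels seen by the four BatchNormalization layers

non_trainable = sum(batchnorm_params(c)["non_trainable"] for c in bn_channels)
trainable = sum(convs) + sum(denses) + sum(batchnorm_params(c)["trainable"] for c in bn_channels)
print(trainable + non_trainable, trainable, non_trainable)  # 2068019 2065597 2422
```

Almost two thirds of all parameters sit in the first Dense layer (1960 × 700 weights plus 700 biases), which is the usual place to shrink the model if it needs to get smaller.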


We'll train the neural network for 14 epochs. At the end of the training, the resulting Keras model is saved to the models directory. Additionally, TensorBoard helps us visualize the training process.

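import os
from datetime import datetime
from tensorflow.keras.callbacks import TensorBoard

epochs = 14
# compile() isn't shown in the original; adam + sparse CE (integer labels) is an assumption
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
logdir = os.path.join("logs", datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = TensorBoard(logdir, histogram_freq=1)
model.fit(train_ds, validation_data=val_ds, epochs=epochs, callbacks=[tensorboard_callback])
# In TF 2.x this writes a SavedModel directory, which the TFLite conversion expects
model.save('./models/model_' + datetime.now().strftime("%Y%m%d-%H%M%S"))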

After 14 epochs of training, the network reaches a validation accuracy of 61.15%, which is not bad for 345 categories, especially since some of them are quite similar. I'm sure there are still some things you can improve to get an even better score.



Last but not least, since the web application requires a TFLite model, we have to convert the Keras model as described here.

import tensorflow as tf

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model("models/<Model>") # path to the SavedModel directory
tflite_model = converter.convert()

# Save the model.
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)


Next, we'll create the web application: a simple FastAPI server that hosts a single static HTML page with a canvas the user can draw on. The predicted labels and their probabilities are displayed as a pie chart. The REST API consists of a single POST endpoint, which is used for transforming the canvas image.

Make sure to check out the live demo.

Webapp Preview


For developing the web application, we'll use the dependencies listed in requirements.txt.


The backend resizes the canvas to 28x28 pixels, the same size as our training images, and crops it to its content to remove blank space.

Instead of using Python, you could accomplish the same in the browser with TensorFlow.js' resizeBilinear function. However, I have had bad experiences with this function: it causes a significant loss of image quality, with unwanted color effects. That's why I'm using Pillow on the backend side.

The /transform endpoint expects an image's strokes and the bounding box for cropping it. With these two parameters, we call transform_img, which draws the image from its strokes, crops it and resizes it. The resulting image is returned from the endpoint.

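from os import remove
from uuid import uuid4

from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image, ImageDraw
from pydantic import BaseModel
from starlette.background import BackgroundTask

class ImageData(BaseModel):
    strokes: list
    box: list

app = FastAPI()

@app.post("/transform")
async def transform(image_data: ImageData):
    # Temporary file; assumes the ./images directory exists
    filepath = "./images/" + str(uuid4()) + ".png"
    img = transform_img(image_data.strokes, image_data.box)
    img.save(filepath)

    # Delete the temporary file after the response has been sent
    return FileResponse(filepath, background=BackgroundTask(remove, path=filepath))

# Serve the static frontend at "/"; mounted after the route so it doesn't shadow /transform
app.mount("/", StaticFiles(directory="static", html=True), name="static")

def transform_img(strokes, box):
    # Calc cropped image size
    width = box[2] - box[0]
    height = box[3] - box[1]

    image = Image.new("RGB", (width, height), color=(255, 255, 255))
    image_draw = ImageDraw.Draw(image)

    for stroke in strokes:
        positions = []
        for i in range(0, len(stroke[0])):
            # Shift canvas coordinates by the box's top-left corner so the
            # drawing lands inside the cropped image
            positions.append((stroke[0][i] - box[0], stroke[1][i] - box[1]))
        image_draw.line(positions, fill=(0, 0, 0), width=3)

    return image.resize(size=(28, 28))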
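To try the endpoint without the frontend, you can send it a request yourself. This is a minimal standard-library sketch that is not part of the project: the helper names and the http://localhost:8000 address (uvicorn's default port) are assumptions. The JSON body mirrors the ImageData model, with strokes in the [[xs, ys], ...] layout and box as [min_x, min_y, max_x, max_y].

```python
import json
import urllib.request

# Assumed server address; adjust to wherever the FastAPI app is running
DEFAULT_URL = "http://localhost:8000/transform"


def build_transform_request(strokes, box, url=DEFAULT_URL):
    """Build a POST request whose JSON body matches the ImageData model."""
    body = json.dumps({"strokes": strokes, "box": box}).encode("utf-8")
    return urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )


def fetch_transformed_png(strokes, box, url=DEFAULT_URL):
    """Send the request and return the PNG bytes returned by the endpoint."""
    with urllib.request.urlopen(build_transform_request(strokes, box, url)) as response:
        return response.read()
```

For example, while the server is running, `fetch_transformed_png([[[10, 60], [10, 60]]], [10, 10, 60, 60])` should return the bytes of a 28x28 PNG containing a single diagonal line.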

The final image looks as follows:

Preprocessed Image

It has a resolution of 28x28 pixels and is cropped to the drawing.


Let's continue with the frontend part of our web application. Here, we need a drawing area that lets the user draw on a canvas. p5.js is a great library for this use case.

NOTE: I will not cover every line of code here, only the important ones. As mentioned above, you can find the complete code here.

Model & Labels

First of all, we load our previously trained TFLite model using tfjs:

const loadModel = async () => {
  console.log("Model loading...")

  // tflite.loadTFLiteModel comes from the @tensorflow/tfjs-tflite package
  model = await tflite.loadTFLiteModel("./models/model.tflite")
  model.predict(tf.zeros([1, 28, 28, 1])) // warmup

  console.log(`Model loaded! (${LABELS.length} classes)`)
}

LABELS is an array that contains all 345 image categories. For reasons of space, I placed it in a separate file.
The order of its elements is really important, so don't change it! The model's i-th output belongs to LABELS[i], so a different order makes the model's predictions come out wrong.
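The order that matters is the one used during training: image_dataset_from_directory assigns label indices in alphanumeric order of the class subdirectories (train_ds.class_names holds that list). Rather than maintaining the array by hand, one option is to generate it from the dataset directory. A sketch, where the labels.js filename and the `const LABELS = [...]` layout are assumptions about how the separate file is structured:

```python
import json
from pathlib import Path


def class_names(root="dataset"):
    """Class names in the order image_dataset_from_directory assigns label indices."""
    return sorted(p.name for p in Path(root).iterdir() if p.is_dir())


def write_labels_js(names, path="labels.js"):
    """Write a (hypothetical) labels.js file that defines the LABELS array."""
    Path(path).write_text("const LABELS = " + json.dumps(names, indent=2) + "\n")
```

With the directory tree shown earlier, "aircraft carrier" sorts before "airplane" because a space comes before any letter, which matches the order of the subdirectories listed there.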


Set up p5.js as follows:

const WIDTH = 500
const HEIGHT = 500

function setup() {
  createCanvas(WIDTH, HEIGHT)
}

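Handling mouse movement and clicks inside the canvas:

// clicked, mousePosition, strokePixels ([[], []]) and imageStrokes ([]) are globals;
// inRange is a helper that isn't shown here
function mouseDown() {
  clicked = true
  mousePosition = [mouseX, mouseY]
}

// Check whether mouse position is within canvas
function mouseMoved() {
  if (clicked && inRange(mouseX, 0, WIDTH) && inRange(mouseY, 0, HEIGHT)) {
    // Collect the coordinates of the current stroke
    strokePixels[0].push(mouseX)
    strokePixels[1].push(mouseY)

    line(mouseX, mouseY, mousePosition[0], mousePosition[1])
    mousePosition = [mouseX, mouseY]
  }
}

function mouseReleased() {
  if (strokePixels[0].length) {
    // Finish the current stroke and start a new one
    imageStrokes.push(strokePixels)
    strokePixels = [[], []]
  }
  clicked = false
}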

When the mouse is clicked and moved, its x/y coordinates are collected in strokePixels, so the array contains all x and y pixels of the stroke currently being drawn:

[
  [x1, x2, ..., xn],
  [y1, y2, ..., yn]
]

When the mouse is released, the current stroke is finished and added to the imageStrokes array, which contains all strokes drawn on the canvas. In other words, it's an array of strokePixels:

[
  // First stroke
  [[x0, x1, x2, x3, ...], [y0, y1, y2, y3, ...]],

  // Second stroke
  [[x0, x1, x2, x3, ...], [y0, y1, y2, y3, ...]],

  // Additional strokes
  ...
]
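To make the bookkeeping explicit, here is the same logic as a small Python class, which can be handy for checking the data layout outside the browser. It is a hypothetical re-implementation, not code from the project; in particular, the canvas bounds check assumes inRange is inclusive on both ends, since that helper isn't shown.

```python
class StrokeRecorder:
    """Collect strokes in the same [[xs], [ys]] layout as strokePixels / imageStrokes."""

    def __init__(self, width=500, height=500):
        self.width, self.height = width, height
        self.clicked = False
        self.stroke_pixels = [[], []]  # stroke currently being drawn
        self.image_strokes = []  # all finished strokes

    def mouse_down(self):
        self.clicked = True

    def mouse_moved(self, x, y):
        # Only record points while the button is held and the cursor is inside the canvas
        if self.clicked and 0 <= x <= self.width and 0 <= y <= self.height:
            self.stroke_pixels[0].append(x)
            self.stroke_pixels[1].append(y)

    def mouse_released(self):
        # Empty strokes (a click without movement) are discarded
        if self.stroke_pixels[0]:
            self.image_strokes.append(self.stroke_pixels)
            self.stroke_pixels = [[], []]
        self.clicked = False


recorder = StrokeRecorder()
recorder.mouse_down()
recorder.mouse_moved(10, 20)
recorder.mouse_moved(30, 40)
recorder.mouse_released()
print(recorder.image_strokes)  # [[[10, 30], [20, 40]]]
```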


Before predicting the label, we have to preprocess the canvas using our /transform endpoint and TensorFlow.js. To do so, we use the imageStrokes array, which contains all of the canvas' strokes:

const preprocess = async cb => {
  const { min, max } = getBoundingBox()

  // Resize to 28x28 pixel & crop
  const imageBlob = await fetch("/transform", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    redirect: "follow",
    referrerPolicy: "no-referrer",
    body: JSON.stringify({
      strokes: imageStrokes,
      box: [min.x, min.y, max.x, max.y],
    }),
  }).then(response => response.blob())

  const img = new Image(28, 28)
  img.src = URL.createObjectURL(imageBlob)

  img.onload = () => {
    // Grayscale pixels -> float tensor with a batch dimension, shape [1, 28, 28, 1].
    // The rest of this chain isn't shown in the original; these steps are an assumption.
    const tensor = tf.tidy(() =>
      tf.browser
        .fromPixels(img, 1)
        .expandDims(0)
        .cast("float32")
    )
    cb(tensor)
  }
}

The function getBoundingBox calculates the minimum and maximum x and y coordinates of the drawing inside the canvas. These values are used on the backend side to crop the canvas and remove the surrounding white space.
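getBoundingBox itself isn't shown. As an illustration of what it computes, here is a hypothetical Python equivalent (not the frontend's code) that takes the imageStrokes layout and produces the same [min.x, min.y, max.x, max.y] list that is sent to /transform as box:

```python
def get_bounding_box(strokes):
    """Return ((min_x, min_y), (max_x, max_y)) over all points of all strokes."""
    xs = [x for stroke in strokes for x in stroke[0]]
    ys = [y for stroke in strokes for y in stroke[1]]
    if not xs:
        raise ValueError("no points to bound")
    return (min(xs), min(ys)), (max(xs), max(ys))


strokes = [[[120, 180, 240], [90, 60, 90]], [[150, 210], [200, 200]]]
(min_x, min_y), (max_x, max_y) = get_bounding_box(strokes)
box = [min_x, min_y, max_x, max_y]
print(box)  # [120, 60, 240, 200]
```

On the backend, this box yields a 120 × 140 image before it is resized to 28x28.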


When making predictions, we pass the tensor produced by preprocess to our model. Afterwards, we select the top 3 predictions and display their probabilities in a pie chart.

const predict = async () => {
  if (!imageStrokes.length) return
  if (!LABELS.length) throw new Error("No labels found!")

  preprocess(tensor => {
    const predictions = model.predict(tensor).dataSync()

    const top3 = Array.from(predictions)
      .map((p, i) => ({
        probability: p,
        className: LABELS[i],
        index: i,
      }))
      .sort((a, b) => b.probability - a.probability)
      .slice(0, 3)

    // ... render top3 as a pie chart (not shown here)
  })
}
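The same top-3 selection can be reproduced in Python, for example to check the converted model's output offline. This sketch (function name and sample labels are made up) uses heapq.nlargest instead of the frontend's full sort; both give the same result, and for 345 values either is fast enough.

```python
import heapq


def top_k(probabilities, labels, k=3):
    """Return the k most probable classes as dicts, highest probability first."""
    ranked = heapq.nlargest(k, enumerate(probabilities), key=lambda pair: pair[1])
    return [{"probability": p, "className": labels[i], "index": i} for i, p in ranked]


top = top_k([0.1, 0.6, 0.05, 0.25], ["cat", "dog", "fish", "bird"])
print([d["className"] for d in top])  # ['dog', 'bird', 'cat']
```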

That's it! You just created an application that recognizes hand drawn images using Deep Learning. Show the app to your friends and family. I'm sure they'll be quite impressed.
