DEV Community

Ferdo Vukojević

Data Structures: QuadTree

Hello Community! 👋

I was thinking about starting a series on dev.to called Wednesday Data Structures (yes, I know, the name is very original). In this series I will try to post a new data structure every Wednesday and go in depth about how it works and when it is used. Hopefully I will have some cool examples and images to showcase as well.

There are so many cool structures out there: trees, heaps, tries, linked lists, doubly linked lists, and so many more.

QuadTree

For the first of hopefully many data structures in this series, I decided to talk about quadtrees.

1) What's a QuadTree?

Quadtrees are trees used to efficiently store points in a two-dimensional space. In this tree, each node has either exactly four children or none at all (a leaf).
We can construct a quadtree from a two-dimensional area using the following steps:

  • Divide the current two-dimensional space into four equal boxes.
  • Create a child node for each box, storing in it the two-dimensional area of the box and the points that fall inside it.
  • If a box contains no points (or, as in the implementation below, no more points than a chosen threshold), stop there: that child is a leaf.
  • Otherwise, recurse for that child.
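The steps above can be sketched in a few lines of Python. This is a minimal illustration under a few assumptions, not the implementation used later in this post: boxes are plain `(x, y, w, h)` tuples with y growing upwards (as in the matplotlib plot), the hypothetical `build` function stops splitting once a box holds at most `capacity` points, and every point goes to exactly one quadrant by comparing it with the box's midpoint. Like the implementation later on, it always creates all four children and leaves the empty ones as leaves. The `max_depth` guard is also an assumption, added so the recursion stops even if many points share the same coordinates.

```python
def split(box):
    """Split an (x, y, w, h) box into its four equal quadrants."""
    x, y, w, h = box
    hw, hh = w / 2, h / 2
    return [
        (x, y, hw, hh),            # 0: lower-left
        (x, y + hh, hw, hh),       # 1: upper-left
        (x + hw, y, hw, hh),       # 2: lower-right
        (x + hw, y + hh, hw, hh),  # 3: upper-right
    ]


def quadrant_index(box, p):
    """Index (0-3, same order as split) of the quadrant that point p falls into."""
    x, y, w, h = box
    right = p[0] >= x + w / 2
    top = p[1] >= y + h / 2
    return 2 * right + top


def build(box, points, capacity=1, depth=0, max_depth=16):
    """Recursively build a quadtree as nested dicts."""
    node = {"box": box, "points": points, "children": []}
    # A box with few enough points (or at the depth limit) stays a leaf
    if len(points) <= capacity or depth >= max_depth:
        return node
    # Divide into four boxes and hand each point to exactly one of them
    buckets = [[], [], [], []]
    for p in points:
        buckets[quadrant_index(box, p)].append(p)
    # Recurse for each of the four children (empty ones become leaves)
    node["children"] = [
        build(sub, bucket, capacity, depth + 1, max_depth)
        for sub, bucket in zip(split(box), buckets)
    ]
    return node


def leaves(node):
    """All leaf nodes below (and including) node."""
    if not node["children"]:
        return [node]
    return [leaf for child in node["children"] for leaf in leaves(child)]


tree = build((0, 0, 10, 10), [(1, 1), (2, 8), (7, 3), (8, 8), (8.5, 9)])
print(len(leaves(tree)))  # → 10
```

Lowering `capacity` gives a finer subdivision; raising it gives fewer, larger leaves.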

2) When's it used?

So let's say you own a car dealership company, and your main focus is to sell as many cars as you can. Let's also say you don't want to open a shop in a place with a lot of other car dealerships, because the competition may hurt your business. Your problem is very simple:

  1. Look over the area where you want to start your business
  2. See where the fewest car dealerships are
  3. Open your shop there, and start selling!

Can you see how quadtrees can help here? Think about it: if every point on a 2D map is a car dealership, we can keep dividing our quadtree until we reach smaller areas with no points at all, or at least with no more points than the threshold we give it, and there we start building!
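Here is one way the dealership idea could look in code. It is a sketch under made-up assumptions: the dealership coordinates are sample data, `leaf_regions` and `best_spot` are hypothetical helpers, and "best" is taken to mean the leaf with the fewest dealerships, preferring the larger area when two leaves tie.

```python
def leaf_regions(box, points, threshold, depth=0, max_depth=12):
    """Yield (box, count) for every leaf of a quadtree built over points."""
    x, y, w, h = box
    # A box with at most `threshold` dealerships is not split any further
    if len(points) <= threshold or depth >= max_depth:
        yield box, len(points)
        return
    hw, hh = w / 2, h / 2
    subs = [(x, y, hw, hh), (x, y + hh, hw, hh),
            (x + hw, y, hw, hh), (x + hw, y + hh, hw, hh)]
    buckets = [[], [], [], []]
    for px, py in points:
        # Index 0-3 matches the order of `subs` above
        buckets[2 * (px >= x + hw) + (py >= y + hh)].append((px, py))
    for sub, bucket in zip(subs, buckets):
        yield from leaf_regions(sub, bucket, threshold, depth + 1, max_depth)


def best_spot(box, dealerships, threshold):
    """Pick the leaf with the fewest dealerships; on ties, prefer the larger area."""
    regions = list(leaf_regions(box, dealerships, threshold))
    return min(regions, key=lambda r: (r[1], -r[0][2] * r[0][3]))


# Made-up sample data: most dealerships are clustered in the lower-left corner
dealerships = [(1, 1), (1.5, 2), (2, 1.2), (3, 3), (8, 2)]
spot, count = best_spot((0, 0, 10, 10), dealerships, threshold=2)
print(spot, count)  # → (0, 5.0, 5.0, 5.0) 0
```

Both empty upper quarters are equally good candidates here; `min` simply returns the first one it finds.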

That was just one very simple use of quadtrees. Some more advanced, real-world uses are:

1.) Image compression

Each node stores the average colour of the image region it covers, which is the average of its children's colours. The deeper you traverse the tree, the more detail of the image you get.
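A minimal sketch of that idea, assuming a grayscale image stored as a list of lists with a square, power-of-two side (real images have colour channels and arbitrary sizes). Every node stores the average of the pixels it covers, and a block is only split when its pixels differ by more than `tolerance`. The hypothetical `render` function rebuilds the image; stopping it at a smaller `max_depth` gives a coarser preview, which is the "deeper means more detail" property described above.

```python
def compress(img, x, y, size, tolerance):
    """Build a region quadtree over the size x size block whose top-left pixel is (x, y)."""
    block = [img[r][c] for r in range(y, y + size) for c in range(x, x + size)]
    node = {"x": x, "y": y, "size": size,
            "avg": sum(block) / len(block),  # every node keeps its average
            "children": []}
    # Uniform enough (or a single pixel): keep only the average
    if size == 1 or max(block) - min(block) <= tolerance:
        return node
    half = size // 2
    node["children"] = [compress(img, x + dx, y + dy, half, tolerance)
                        for dy in (0, half) for dx in (0, half)]
    return node


def render(node, out, max_depth, depth=0):
    """Paint node averages into `out`, stopping at max_depth for a coarser image."""
    if not node["children"] or depth == max_depth:
        for r in range(node["y"], node["y"] + node["size"]):
            for c in range(node["x"], node["x"] + node["size"]):
                out[r][c] = node["avg"]
        return
    for child in node["children"]:
        render(child, out, max_depth, depth + 1)


def count_leaves(node):
    if not node["children"]:
        return 1
    return sum(count_leaves(ch) for ch in node["children"])


image = [
    [0, 0, 255, 255],
    [0, 0, 255, 255],
    [0, 0, 0, 10],
    [0, 0, 10, 0],
]
lossless = compress(image, 0, 0, 4, tolerance=0)
lossy = compress(image, 0, 0, 4, tolerance=10)
print(count_leaves(lossless), count_leaves(lossy))  # → 7 4

full = [[0] * 4 for _ in range(4)]
render(lossless, full, max_depth=10)  # deep enough to reach every leaf
```

With `tolerance=0` the image is stored exactly; a larger tolerance merges near-uniform blocks into a single average, which is where the compression comes from.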

2.) Searching a 2D Area

Kinda the example I gave above. For instance, if you want to find the closest point to given coordinates, a quadtree lets you skip whole regions that are too far away to contain it, instead of checking every single point.
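Here is a minimal sketch of such a closest-point search. It is a standard branch-and-bound approach, not code from the repository: `build`, `box_distance` and `nearest` are hypothetical names, points are `(x, y)` tuples, and all points are assumed to lie inside the root box. The key idea is that if even the nearest edge of a box is farther away than the best point found so far, nothing inside that box can win, so the whole box is skipped.

```python
import math
import random


def build(box, points, capacity=4, depth=0, max_depth=16):
    """Build a quadtree as nested dicts; `points` are (x, y) tuples inside `box`."""
    node = {"box": box, "points": points, "children": []}
    if len(points) <= capacity or depth >= max_depth:
        return node
    x, y, w, h = box
    hw, hh = w / 2, h / 2
    subs = [(x, y, hw, hh), (x, y + hh, hw, hh),
            (x + hw, y, hw, hh), (x + hw, y + hh, hw, hh)]
    buckets = [[], [], [], []]
    for p in points:
        buckets[2 * (p[0] >= x + hw) + (p[1] >= y + hh)].append(p)
    node["children"] = [build(s, b, capacity, depth + 1, max_depth)
                        for s, b in zip(subs, buckets)]
    return node


def box_distance(box, q):
    """Distance from q to the nearest point of box (0 if q lies inside it)."""
    x, y, w, h = box
    dx = max(x - q[0], 0, q[0] - (x + w))
    dy = max(y - q[1], 0, q[1] - (y + h))
    return math.hypot(dx, dy)


def nearest(node, q, best=(math.inf, None)):
    """Return (distance, point) for the stored point closest to q."""
    # Prune: no point in this box can be closer than the current best
    if box_distance(node["box"], q) >= best[0]:
        return best
    if not node["children"]:
        for p in node["points"]:
            d = math.dist(p, q)
            if d < best[0]:
                best = (d, p)
        return best
    # Visit the children closest to q first so the pruning kicks in early
    for child in sorted(node["children"], key=lambda c: box_distance(c["box"], q)):
        best = nearest(child, q, best)
    return best


random.seed(0)
pts = [(random.uniform(0, 100), random.uniform(0, 100)) for _ in range(500)]
tree = build((0, 0, 100, 100), pts)
dist, point = nearest(tree, (42.0, 17.5))
# Sanity check against a brute-force scan over every point
assert dist == min(math.dist(p, (42.0, 17.5)) for p in pts)
print(point, dist)
```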

3.) Collision Detection

A brute-force collision detection algorithm compares every pair of objects, which takes O(n²) time. (Too slow!)

With a quadtree, we only need to compare objects that sit in the same or nearby squares, which greatly reduces the number of comparisons.
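A sketch of a quadtree-based collision check, under simplifying assumptions: every object is a circle with the same radius `r`, stored as an `(x, y, id)` tuple, and all centres lie inside the square root box. Rather than literally walking adjacent squares, this version asks the tree for the objects inside a small box around each circle, which has the same effect: far-away regions are never visited. `build`, `query`, `colliding_pairs` and `brute_force_pairs` are hypothetical names; the O(n²) brute-force version is only there to check the result.

```python
import math
import random
from itertools import combinations


def build(box, items, capacity=8, depth=0, max_depth=16):
    """Quadtree over (x, y, id) items, as nested dicts."""
    node = {"box": box, "items": items, "children": []}
    if len(items) <= capacity or depth >= max_depth:
        return node
    x, y, w, h = box
    hw, hh = w / 2, h / 2
    subs = [(x, y, hw, hh), (x, y + hh, hw, hh),
            (x + hw, y, hw, hh), (x + hw, y + hh, hw, hh)]
    buckets = [[], [], [], []]
    for it in items:
        buckets[2 * (it[0] >= x + hw) + (it[1] >= y + hh)].append(it)
    node["children"] = [build(s, b, capacity, depth + 1, max_depth)
                        for s, b in zip(subs, buckets)]
    return node


def overlaps(a, b):
    """Do two (x, y, w, h) boxes overlap? Touching edges count."""
    return (a[0] <= b[0] + b[2] and b[0] <= a[0] + a[2]
            and a[1] <= b[1] + b[3] and b[1] <= a[1] + a[3])


def query(node, rect, found):
    """Append every item whose centre lies in rect; skip boxes that miss rect."""
    if not overlaps(node["box"], rect):
        return found
    if node["children"]:
        for child in node["children"]:
            query(child, rect, found)
    else:
        rx, ry, rw, rh = rect
        found.extend(it for it in node["items"]
                     if rx <= it[0] <= rx + rw and ry <= it[1] <= ry + rh)
    return found


def colliding_pairs(balls, r, size):
    """Narrow the candidates with the quadtree, then do an exact circle-circle check."""
    tree = build((0, 0, size, size), balls)
    pairs = set()
    for a in balls:
        # Two circles of radius r can only touch if their centres are < 2r apart
        nearby = query(tree, (a[0] - 2 * r, a[1] - 2 * r, 4 * r, 4 * r), [])
        for b in nearby:
            if a[2] < b[2] and math.dist(a[:2], b[:2]) < 2 * r:
                pairs.add((a[2], b[2]))
    return pairs


def brute_force_pairs(balls, r):
    """O(n^2): compare every pair."""
    return {(a[2], b[2]) for a, b in combinations(balls, 2)
            if math.dist(a[:2], b[:2]) < 2 * r}


random.seed(1)
balls = [(random.uniform(0, 100), random.uniform(0, 100), i) for i in range(300)]
print(len(colliding_pairs(balls, 1.0, 100)))
```

Each circle only looks at the handful of objects returned by `query`, instead of all n - 1 others.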

3) Implementation

Now the fun part, implementation. There are many implementations out there, and I wanted to create one by myself. It might not be the most optimized one, but it makes a lot of sense to me and hopefully you will also be able to read it.

Just for reference, I will be using Python for this, as I feel matplotlib is the easiest way to visualise what I want to achieve in the end.

Let's start by looking at the 3 classes we will need:

Point

Well, we said we want to keep points on a map. And what does a point have to have? Coordinates (plus some data describing what the point is). So let's create it!

point.py

class Point:
    def __init__(self, x, y, data):
        self.data = data
        self.x = x
        self.y = y

    def __str__(self):
        return 'P({:.2f}, {:.2f})'.format(self.x, self.y)

Node

As we said in the beginning, each node of a quadtree has either exactly 4 children or none. So we need a Node class to represent each 2D square in the tree.

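from classes.point import Point


class Node:
    def __init__(self, x, y, w, h, points=None):
        # (x, y) is the lower-left corner of this node's box
        self.x = x
        self.y = y
        self.width = w
        self.height = h
        # Use a fresh list when no points are given, so add_point() works
        self.points = points if points is not None else []
        # Either empty (a leaf) or exactly 4 child Nodes
        self.children = []

    def set_points(self, points):
        self.points = points

    def add_point(self, x, y, data=None):
        # Point's constructor requires data, so pass it through
        self.points.append(Point(x, y, data))

    def get_width(self):
        return self.width

    def get_height(self):
        return self.height

    def get_points(self):
        return self.points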

Because we are going to plot the tree with matplotlib in the end, storing each node's coordinates, width and height makes drawing it as a rectangle straightforward.

Also, we want to keep track of how many points (think car dealerships) every node has, so we can stop splitting it into 4 once we reach our goal.

QuadTree

Our main class, where all the "magic" takes place.

from classes.node import Node
from classes.point import Point
from matplotlib import pyplot as plt
from matplotlib import patches


class QuadTree:
    def __init__(self, threshold):
        # A node is split while it holds more than `threshold` points
        self.threshold = threshold
        # Root covers the 10 x 10 area used in start.py; start with an empty list, not None
        self.root = Node(0, 0, 10, 10, [])

    def add_point(self, x, y, data=None):
        # graph() expects data to be a dict with a 'color' key
        self.root.points.append(Point(x, y, data))

    def get_points(self):
        return self.root.get_points()

    def subdivide(self):
        recursive_subdivide(self.root, self.threshold)

    def graph(self):
        fig = plt.figure(figsize=(12, 8))
        # Create the axes first, then title it, so no second empty axes is created
        ax = fig.add_subplot(111)
        ax.set_title("Quadtree")
        leaves = find_children(self.root)
        print(f"Number of segments: {len(leaves)}")
        print(f"Minimum segment area: {min(n.width * n.height for n in leaves)}")
        # Draw the outline of every leaf segment
        for n in leaves:
            ax.add_patch(patches.Rectangle((n.x, n.y), n.width, n.height, fill=False))
        x = [point.x for point in self.root.points]
        y = [point.y for point in self.root.points]
        colors = [point.data['color'] for point in self.root.points]
        red_patch = patches.Patch(color='red', label='Gas Station')
        blue_patch = patches.Patch(color='blue', label='Police')
        yellow_patch = patches.Patch(color='yellow', label='Hospital')
        ax.legend(handles=[red_patch, blue_patch, yellow_patch])
        ax.scatter(x, y, c=colors)
        plt.show()


def recursive_subdivide(node, k):
    # Stop once the node holds at most k points.
    # Note: if more than k points share exactly the same coordinates,
    # this never terminates (Python eventually raises RecursionError).
    if len(node.points) <= k:
        return

    w_ = node.width / 2
    h_ = node.height / 2

    # Lower-left quarter
    p = contains(node.x, node.y, w_, h_, node.points)
    x1 = Node(node.x, node.y, w_, h_, p)
    recursive_subdivide(x1, k)

    # Upper-left quarter
    p = contains(node.x, node.y + h_, w_, h_, node.points)
    x2 = Node(node.x, node.y + h_, w_, h_, p)
    recursive_subdivide(x2, k)

    # Lower-right quarter
    p = contains(node.x + w_, node.y, w_, h_, node.points)
    x3 = Node(node.x + w_, node.y, w_, h_, p)
    recursive_subdivide(x3, k)

    # Upper-right quarter
    p = contains(node.x + w_, node.y + h_, w_, h_, node.points)
    x4 = Node(node.x + w_, node.y + h_, w_, h_, p)
    recursive_subdivide(x4, k)

    node.children = [x1, x2, x3, x4]


def contains(x, y, w, h, points):
    # Both edges are inclusive, so a point lying exactly on a shared edge
    # ends up in both neighbouring quarters
    pts = []
    for point in points:
        if x <= point.x <= x + w and y <= point.y <= y + h:
            pts.append(point)
    return pts


def find_children(node):
    """Return every leaf node (a node without children) under `node`."""
    if not node.children:
        return [node]
    leaves = []
    for child in node.children:
        leaves += find_children(child)
    return leaves

As there are quite a few things to discuss, let's go through them one by one.

The class takes 1 parameter called threshold: the maximum number of points a node may hold without being split. A node with more points than that is split into 4, and the splitting stops once every leaf is at or under the threshold.

add_point and get_points are pretty straightforward. graph is the main method: it plots the whole tree. It collects every leaf node (a node with no children; remember, every node has 4 or 0) with the find_children helper and draws all of them on the same canvas. I also colour-coded the points just for fun: red for gas stations, blue for police and yellow for hospitals.

The subdivide method keeps splitting the tree until every leaf holds no more than threshold points. The logic is pretty simple: count the points in a node, and if there are more than the threshold, split it into 4 quarters and call recursive_subdivide on each of them.
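To get a feel for what the threshold does, here is a small standalone sketch that follows the same stopping rule as recursive_subdivide (split while a node holds more than threshold points) but only counts the resulting leaf segments, without matplotlib. It is an illustration, not code from the repository: `count_segments` is a hypothetical name, it assigns each point to exactly one quarter (unlike `contains`, which can count a point lying exactly on a shared edge in two quarters), and `max_depth` is an added safety guard.

```python
import random


def count_segments(box, points, threshold, depth=0, max_depth=20):
    """Count the leaf segments produced by splitting while len(points) > threshold."""
    if len(points) <= threshold or depth >= max_depth:
        return 1
    x, y, w, h = box
    hw, hh = w / 2, h / 2
    subs = [(x, y, hw, hh), (x, y + hh, hw, hh),
            (x + hw, y, hw, hh), (x + hw, y + hh, hw, hh)]
    buckets = [[], [], [], []]
    for px, py in points:
        # Index 0-3 matches the order of `subs`
        buckets[2 * (px >= x + hw) + (py >= y + hh)].append((px, py))
    return sum(count_segments(s, b, threshold, depth + 1, max_depth)
               for s, b in zip(subs, buckets))


random.seed(42)
points = [(random.uniform(0, 10), random.uniform(0, 10)) for _ in range(1000)]
for threshold in (1, 3, 10, 50):
    print(threshold, count_segments((0, 0, 10, 10), points, threshold))
```

A lower threshold produces more, smaller segments. Raising it can never increase the count, because every split made with the higher threshold is also made with the lower one.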

And voilà! Finally, we call everything we created from the main start.py file.

from classes.quadTree import QuadTree
from classes.point import Point
from numpy import random


def random_point_data(items):
    return random.choice(items)


if __name__ == '__main__':
    pointItems = [
        {
            'name': 'gas_station',
            'color': 'red'
        },
        {
            'name': 'hospital',
            'color': 'yellow'
        },
        {
            'name': 'police',
            'color': 'blue'
        }
    ]
    points = [Point(random.uniform(0, 10), random.uniform(0, 10), random_point_data(pointItems)) for x in range(1000)]
    quadTree = QuadTree(3)
    quadTree.root.set_points(points)
    quadTree.subdivide()
    quadTree.graph()

We generate 1000 random points, load them into the root node and subdivide the QuadTree. Here I set the threshold to 3.

What I get in the end looks like this:

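[Figure: the resulting quadtree plot, with every leaf segment drawn as a rectangle over the 1000 coloured points]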

Your image will probably look different from mine, because the points are randomly generated on every run, so keep that in mind.

Conclusion

And that's it for episode 1. Hope you found it useful or at least not so boring to read.

I will leave the git repo link below if you want to check it out. Feel free to contact me on LinkedIn if you ever want to chat!

Github Repo: https://github.com/fvukojevic/QuadTree-Python
