
A short summary of the Union-Find algorithm

Bob Fang (bobfang1992) ・ 7 min read

First, I want to point you to this MIT lecture note, which explains this problem and its solution very well from a theoretical point of view. The first part of this post is just a summary of the MIT note. The second part covers my C++ implementation and benchmarks, and the third part works through example questions I took from Leetcode, with my solutions.

I will not focus much on the running time analysis here, as the MIT notes already provide a detailed discussion of it.

Problem Statement

We want a data structure that supports these three kinds of operations:

  • make-set(u): u is an element that is not in any set yet. make-set creates a new set whose only element is u itself.

  • union(u, v): merge the two sets that elements u and v are in. If either u or v is not already in a set, add it to the other element's set.

  • find(u): find which set element u is in, i.e. return the representative element of that set.

High-Level Design

Linked List Solution

This one is easy to think of: the head of each list is the representative element of its set. find just returns the head of the element's list, and union merges the two linked lists. Both operations are O(n).

Forest of Trees Solution

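Another easy one to think of: each set is a tree whose root is its representative. find walks from an element up to its root, and union links one root under the other, so both take time proportional to the tree height. That is O(lg(n)) when the trees stay balanced, but if no care is taken an unlucky sequence of unions can build a chain of height O(n).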
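To make the forest idea concrete, here is a minimal sketch, assuming the elements are the integers 0..n-1 stored in a parent array. The names NaiveForest and unite are my own (union is a reserved word in C++). It has neither of the optimisations discussed below.

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// Naive forest of trees: parent[i] is i's parent, and a root points to itself.
// No union by rank and no path compression yet.
struct NaiveForest {
  std::vector<std::size_t> parent;

  // make-set for elements 0..n-1: every element starts as its own root.
  explicit NaiveForest(std::size_t n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);
  }

  // find: walk up to the root; the cost is the depth of u.
  std::size_t find(std::size_t u) const {
    while (parent[u] != u) u = parent[u];
    return u;
  }

  // union: hang v's root under u's root, with no balancing at all.
  void unite(std::size_t u, std::size_t v) {
    std::size_t ru = find(u), rv = find(v);
    if (ru != rv) parent[rv] = ru;
  }
};
```

For example, calling unite(i + 1, i) for i = 0, 1, 2, ... builds a single chain, so find(0) has to walk through every node.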

Union by Rank

As we can see, the tree height determines the running time of the tree solution. So how can we optimise it?

One observation we can make is this: say we have two trees, tree(u) and tree(v), where tree(u) means the whole tree that element u is in. If we merge the latter into the former, the new tree's height will be max(height(tree(u)), height(tree(v)) + 1). So we always want to merge the shorter tree into the taller one, not the other way around (tree height matters!). For each tree we therefore record its height, called its rank, and the union algorithm takes the ranks into account when it actually does the merge.

In pseudo-code, union looks like this:

```
Union(u, v):
    ut <- Find(u)
    vt <- Find(v)
    if ut = vt then
        return                    // already in the same set
    if ut.rank = vt.rank then
        ut.rank <- ut.rank + 1
        vt.parent <- ut
    else if ut.rank > vt.rank then
        vt.parent <- ut
    else
        ut.parent <- vt
```

One thing to notice here: if no optimisation other than union by rank is used, rank is the actual height of the tree. But once we also allow the path compression optimisation below, rank is only an upper bound on the tree height.
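Here is one way the pseudo-code above might translate to C++. It is again a sketch on a parent array, and RankedForest is a hypothetical name. Swapping the roots so that ut is always the higher-rank one folds the three branches of the pseudo-code into one.

```cpp
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

// Forest with union by rank (no path compression).
// rank[r] is only meaningful for a root r; without compression it equals the tree height.
struct RankedForest {
  std::vector<std::size_t> parent, rank;

  explicit RankedForest(std::size_t n) : parent(n), rank(n, 0) {
    std::iota(parent.begin(), parent.end(), 0);
  }

  std::size_t find(std::size_t u) const {
    while (parent[u] != u) u = parent[u];
    return u;
  }

  // Attach the lower-rank root under the higher-rank one;
  // on a tie, the surviving root's rank grows by one.
  void unite(std::size_t u, std::size_t v) {
    std::size_t ut = find(u), vt = find(v);
    if (ut == vt) return;                       // already in the same set
    if (rank[ut] < rank[vt]) std::swap(ut, vt); // make ut the taller root
    parent[vt] = ut;
    if (rank[ut] == rank[vt]) ++rank[ut];
  }
};
```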

Path Compression

Another trick to keep the trees short: when you do a find, you can link every node along the path directly to the root, which flattens the tree for later finds. I will omit a detailed description of the algorithm here, as the idea is rather simple.
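For completeness, here is a minimal sketch of find with path compression on the same kind of parent array, assuming a root points to itself (findCompress is a hypothetical name). It makes two passes: one to locate the root and one to re-point each node on the path at it.

```cpp
#include <cstddef>
#include <vector>

// find with path compression on a parent array (a root points to itself).
std::size_t findCompress(std::vector<std::size_t> &parent, std::size_t u) {
  // First pass: walk up to the root.
  std::size_t root = u;
  while (parent[root] != root) root = parent[root];
  // Second pass: point every node on the path directly at the root,
  // so later finds on these nodes take a single step.
  while (u != root) {
    std::size_t next = parent[u];
    parent[u] = root;
    u = next;
  }
  return root;
}
```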

C++ Implementation

First Iteration

So I went straight to the tree solution. Here is a snippet of the first implementation.

```cpp
#include <cstddef>
#include <memory>
#include <utility>

template <typename T>
class UFTree : public std::enable_shared_from_this<UFTree<T>> {
  std::size_t d_rank;                  // always 0 in this first iteration
  T d_element;                         // the value stored in this node
  std::shared_ptr<UFTree<T>> d_parent; // null for a root

public:
  UFTree(const T &e) : d_rank(0), d_element(e), d_parent() {}
  UFTree(T &&e) : d_rank(0), d_element(std::move(e)), d_parent() {}

  T value() const { return d_element; }
  T &ref() { return d_element; }
  const T &ref() const { return d_element; } // const overload must return const T&
  std::size_t rank() const { return d_rank; }

  std::shared_ptr<UFTree<T>> parent() const { return d_parent; }
  // Ignore requests to make a node its own parent.
  void setParent(std::shared_ptr<UFTree<T>> p) {
    if (p != this->shared_from_this()) {
      d_parent = p;
    }
  }

  // Walk up the parent links to the root (no path compression yet).
  std::shared_ptr<UFTree<T>> root() {
    std::shared_ptr<UFTree<T>> rtn =
        this->shared_from_this(); // default to return the node itself
    while (rtn->parent()) {
      rtn = rtn->parent();
    }
    return rtn;
  }

  T findSet() { return root()->d_element; }

  // Hang this tree's root under the other tree's root (no union by rank yet).
  std::shared_ptr<UFTree<T>> unionSet(UFTree &other) {
    std::shared_ptr<UFTree<T>> thisP = root();
    std::shared_ptr<UFTree<T>> thatP = other.root();
    thisP->setParent(thatP);
    return thatP;
  }
};
```

There are a number of things to look out for here.

The first thing to notice is that the design is very rough: it uses dynamic storage everywhere and expects the user to do so as well. And if you look closely, I used enable_shared_from_this here, which means you cannot even write something like this:

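```cpp
UFTree<int> t1(1);
UFTree<int> t2(2);
// t1 and t2 are locally allocated, so no shared_ptr owns them
// and shared_from_this() cannot work: since C++17 this throws
// std::bad_weak_ptr (before C++17 it is undefined behaviour).
t1.unionSet(t2);
```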

Instead, you have to use shared_ptr all the way, e.g. auto t1 = std::make_shared<UFTree<int>>(1);

One suggestion was to use a factory pattern here. What I chose to do in the end was to add an extra layer on top of the UFTree class that wraps the shared_ptr. The code looks like this:

```cpp
#include <memory>
#include <utility>
// Assumes UFTree<T> from above is defined in the same header.

template <typename T> class UFElement {
  std::shared_ptr<UFTree<T>> d_tree; // shared handle to the underlying node

public:
  UFElement() {}
  // Deliberately non-explicit, so a T (e.g. an int) converts to a UFElement.
  UFElement(const T &e) : d_tree(std::make_shared<UFTree<T>>(e)) {}
  UFElement(T &&e) : d_tree(std::make_shared<UFTree<T>>(std::move(e))) {}
  UFElement(std::shared_ptr<UFTree<T>> t) : d_tree(std::move(t)) {}

  T value() const { return d_tree->value(); }

  UFElement root() { return UFElement(d_tree->root()); }
  T findSet() { return root().value(); }
  // Merge this element's set into e's set and return the new root.
  UFElement unionSet(UFElement e) {
    d_tree->unionSet(*e.d_tree);
    return e.root();
  }
  // Two handles are equal when they refer to the same node.
  friend bool operator==(const UFElement &element, const UFElement &other) {
    return element.d_tree == other.d_tree;
  }
};
```

Of course, this adds a bit of overhead. In a very rough benchmark, using UFElement instead of UFTree directly added about 20% overhead when constructing the objects.

Before we move on to the optimisations mentioned earlier, let's look at some examples that use this API.

Example Problems

Graph Valid Tree

The problem statement is easy to understand: given a number of nodes and a list of edges, we want to find out whether the graph they describe forms a tree.

Note that a tree is a graph without a cycle, and understanding this tells us how to solve the problem. When we union the two endpoints of an edge and they already belong to the same set, that edge creates a cycle. So to solve the problem we just need to union the nodes according to the edge list and check, before each union, whether the two nodes are already in the same set.

However, there is one more thing to consider. For example, if n is 4 and edges is [(0,1),(2,3)], we pass the check above, yet in the end we have two trees instead of one. So after joining all the edges, we have to check how many sets there are, and return true only if all the nodes are in the same set.
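The check described above can also be done by counting sets while unioning, instead of comparing every node's set at the end. This is a sketch under my own assumptions: validTreeByCounting is a hypothetical function, and it uses a bare parent array with no rank or compression. Every successful union merges two sets, so the graph is connected exactly when the count ends at 1.

```cpp
#include <numeric>
#include <utility>
#include <vector>

// Valid-tree check that counts the remaining sets while unioning.
bool validTreeByCounting(int n, const std::vector<std::pair<int, int>> &edges) {
  std::vector<int> parent(n);
  std::iota(parent.begin(), parent.end(), 0); // n singleton sets
  auto find = [&parent](int u) {
    while (parent[u] != u) u = parent[u];
    return u;
  };
  int components = n;
  for (const auto &e : edges) {
    int ra = find(e.first), rb = find(e.second);
    if (ra == rb) return false; // endpoints already connected: this edge closes a cycle
    parent[rb] = ra;
    --components;               // a successful union merges two sets into one
  }
  return components == 1;       // connected iff exactly one set remains
}
```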

Let's look at my solutions. In this first version I did not use the union-find data structure at all; I just used a plain vector that maps each node to a set label.

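```cpp
#include <utility>
#include <vector>
using namespace std;

class PlainSolution {
public:
  bool validTree(int n, vector<pair<int, int>> &edges) {
    // uf[i] is the label of the set node i belongs to; start with n singletons.
    vector<int> uf(n);
    for (int i = 0; i < n; i++) {
      uf[i] = i;
    }

    for (const auto &p : edges) {
      int f = p.first;
      int s = p.second;

      if (uf[f] == uf[s]) {
        return false; // both ends already in one set: this edge closes a cycle
      }
      // Relabel every member of s's set with f's label (O(n) per edge).
      int temp = uf[s];
      for (int i = 0; i < n; i++) {
        if (uf[i] == temp) {
          uf[i] = uf[f];
        }
      }
    }

    // A tree must be connected: every node must carry the same label.
    for (int i = 1; i < n; i++) {
      if (uf[i] != uf[0])
        return false;
    }
    return true;
  }
};
```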

And then a version that uses the data structure we built.

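```cpp
#include <cstddef>
#include <utility>
#include <vector>
#include <disjoint_set.h> // my header providing UFElement
using namespace std;

class Solution {
public:
  bool validTree(int n, vector<pair<int, int>> &edges) {
    vector<UFElement<int>> v;
    for (int i = 0; i < n; i++) {
      v.push_back(i); // implicit conversion: one singleton set per node
    }
    for (const auto &p : edges) {
      int t1 = v[p.first].findSet();
      int t2 = v[p.second].findSet();

      if (t1 == t2)
        return false; // already in the same set: this edge creates a cycle
      v[p.first].unionSet(v[p.second]);
    }

    // Connected iff neighbouring nodes all share a representative.
    // i + 1 < v.size() avoids unsigned underflow when n == 0.
    for (size_t i = 0; i + 1 < v.size(); i++) {
      if (v[i].findSet() != v[i + 1].findSet())
        return false;
    }
    return true;
  }
};
```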

The vector solution is faster than the one that uses our customised data structure. Some Google Benchmark statistics:

[Image: benchmark results]

The two algorithms work in almost exactly the same way, so I tried to optimise the second solution a little to catch up with the first one (the plain one in the screenshot above). The first trick is very simple: calling v.reserve(n) before the push_back loop reduced the running time from 7000-ish to around 6000.

Next, it was time to add path compression and union by rank. The code can be found here. After these optimisations the running time dropped to 4500-ish on my MacBook Pro. On the Leetcode website, however, I saw a larger jump in performance: the DisjointSet solution outperforms the plain vector solution there (13ms vs 19ms). I suspect this is because my local test case is relatively small, while the website uses larger test cases.
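The linked code is not reproduced here. To give a rough idea of how the two optimisations fit a shared_ptr-based design like UFTree, here is a hypothetical sketch; RankedNode, findRoot and unionSets are my names, not the author's. Each node holds a shared_ptr to its parent, findRoot compresses the path it walks, and unionSets links by rank. Parents never point back at children, so there are no shared_ptr ownership cycles.

```cpp
#include <cstddef>
#include <memory>
#include <utility>

template <typename T> struct RankedNode {
  T value;
  std::size_t rank = 0;               // upper bound on height once paths are compressed
  std::shared_ptr<RankedNode> parent; // null for a root
  explicit RankedNode(T v) : value(std::move(v)) {}
};

// Return the root of n's tree and re-point every node on the path at it.
template <typename T>
std::shared_ptr<RankedNode<T>> findRoot(std::shared_ptr<RankedNode<T>> n) {
  std::shared_ptr<RankedNode<T>> root = n;
  while (root->parent) root = root->parent;
  while (n != root) {
    std::shared_ptr<RankedNode<T>> next = n->parent;
    n->parent = root; // path compression
    n = next;
  }
  return root;
}

// Union by rank; returns the root of the merged tree.
template <typename T>
std::shared_ptr<RankedNode<T>> unionSets(const std::shared_ptr<RankedNode<T>> &a,
                                         const std::shared_ptr<RankedNode<T>> &b) {
  std::shared_ptr<RankedNode<T>> ra = findRoot(a), rb = findRoot(b);
  if (ra == rb) return ra;                // already in the same set
  if (ra->rank < rb->rank) std::swap(ra, rb);
  rb->parent = ra;                        // shorter tree goes under the taller one
  if (ra->rank == rb->rank) ++ra->rank;
  return ra;
}
```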

To test this theory, I wanted to generate a larger random tree and measure the performance of my algorithm on it. I found this algorithm for generating a random tree, quickly implemented it, and then used it to test the big O of the two solutions. Here is the result:

[Image: benchmark results]

Clearly, when n is small enough the plain vector solution wins, but as n grows its performance gets worse and worse. Notice that when n is 8192, the DisjointSet solution is much quicker than the plain vector solution. The RMS (Google Benchmark's normalised root-mean-square error of the fitted complexity curve) is also much smaller for the DisjointSet solution, meaning its timings closely follow the O(n) line, while the vector solution's timings fit that line poorly. This is expected: the plain solution relabels up to n entries for every edge, so it is O(n^2) overall.

[Image: performance plot]

A performance plot is provided above.

Cool, theory proved!
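For readers who want to reproduce the experiment, here is a minimal sketch of one way to generate a random tree as an edge list. It is not necessarily the algorithm linked above. Each node i is attached to a uniformly random earlier node, and then the labels and the edge order are shuffled. randomTree is a hypothetical helper name, and the result is not a uniform sample over all labeled trees (decoding a random Prüfer sequence would give that).

```cpp
#include <algorithm>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Build a random tree on n nodes as a list of n - 1 edges.
std::vector<std::pair<int, int>> randomTree(int n, unsigned seed) {
  std::mt19937 rng(seed);
  // Random relabelling so node 0 is not always the "oldest" node.
  std::vector<int> label(n);
  std::iota(label.begin(), label.end(), 0);
  std::shuffle(label.begin(), label.end(), rng);

  std::vector<std::pair<int, int>> edges;
  edges.reserve(n > 0 ? n - 1 : 0);
  for (int i = 1; i < n; ++i) {
    // Attach node i to a uniformly chosen earlier node: never creates a cycle.
    std::uniform_int_distribution<int> pick(0, i - 1);
    edges.emplace_back(label[i], label[pick(rng)]);
  }
  std::shuffle(edges.begin(), edges.end(), rng); // hide the insertion order
  return edges;
}
```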

Now we can move on to some more Leetcode problems.

The Second Problem

Number of Islands

TODO: write up my findings here...

The Third Problem

Longest Consecutive Sequence

TODO: write up my findings here...

The problem I ran into is that it is very hard to optimise my solutions for Number of Islands and Longest Consecutive Sequence, and I am not sure why yet. This needs a bit more digging.

Discussion

arj

For the last problem I'd suggest using a helper/factory as you mentioned:

```cpp
auto t1 = make_tree(1);
auto t2 = make_tree(2);
```

which creates a shared_ptr to a tree for you.

Implementation (untested code):

```cpp
template <typename T>
std::shared_ptr<UFTree<T>> make_tree(const T& t) {
  return std::make_shared<UFTree<T>>(t);
}
```