
Isabella Tromba for Tangram

Posted on • Originally published at tangram.dev

Writing the fastest GBDT library in Rust

In this post, we will go over how we optimized our Gradient Boosted Decision Tree library. This is based on a talk that we gave at RustConf 2021: Writing the Fastest Gradient Boosted Decision Tree Library in Rust. The code is available on GitHub.

The content of this post is organized into the following sections:

  1. What are GBDTs?
  2. Use Rayon to Parallelize.
  3. Use cargo-flamegraph to find bottlenecks.
  4. Use cargo-asm to identify suboptimal code generation.
  5. Use intrinsics to optimize for specific CPUs.
  6. Use unsafe code to eliminate unnecessary bounds checks.
  7. Use unsafe code to parallelize non-overlapping memory access.

What are GBDTs?

GBDT stands for Gradient Boosted Decision Tree. GBDTs are a type of machine learning model that perform incredibly well on tabular data. Tabular data is data that you would normally find in a spreadsheet or a CSV file.

In order to get a feeling for how GBDTs work, let's go through an example of making a prediction with a single decision tree. Let’s say you want to predict the price of a house based on features like the number of bedrooms, bathrooms, and square footage. Here is a table with three features: num_bedrooms, num_bathrooms, and sqft. The final column, price, is what we are trying to predict.

num_bedrooms    num_bathrooms    sqft    price
2               2                1200    $300k
4               3.5              2300    $550k
3               3                2200    $450k
8               9                7000    $990k

To make a prediction with a decision tree, you start at the top of the tree, and at each branch you ask how one of the features compares with a threshold. If the value is less than or equal to the threshold, you go to the left child. If the value is greater, you go to the right child. When you reach a leaf, you have the prediction.

Example Decision Tree

Let’s make an example prediction. We have a house with 3 bedrooms, 3 bathrooms, and 2500 square feet. Let’s see what price our decision tree predicts.

num_bedrooms    num_bathrooms    sqft    price
3               3                2500    ?

Starting at the top, the number of bedrooms is 3 which is less than or equal to 3, so we go left. The square footage is 2500 which is greater than 2400, so we go right, and we arrive at the prediction which is $512K.

Example Prediction Single Decision Tree
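
As a concrete illustration, here is one way a tree like the one in the figure could be represented and evaluated in Rust. This is just a minimal sketch, not Tangram's actual tree representation; the Node type and predict function are ours.

// One node of a decision tree: either a branch on a feature and threshold,
// or a leaf holding a prediction.
enum Node {
    Branch {
        feature: usize,
        threshold: f32,
        left: Box<Node>,
        right: Box<Node>,
    },
    Leaf {
        value: f32,
    },
}

// Walk from the root: <= threshold goes left, > threshold goes right,
// and the leaf we reach is the prediction.
fn predict(node: &Node, features: &[f32]) -> f32 {
    match node {
        Node::Branch { feature, threshold, left, right } => {
            if features[*feature] <= *threshold {
                predict(left, features)
            } else {
                predict(right, features)
            }
        }
        Node::Leaf { value } => *value,
    }
}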

A single decision tree isn’t very good at making predictions on its own, so we train a bunch of trees, one at a time, where each tree predicts the error in the sum of the outputs of the trees before it. This is called gradient boosting over decision trees!

Making a prediction with gradient boosted decision trees is easy. We start with a baseline prediction, which in the case of regression (predicting a continuous value like the price of a home) is just the average price of the houses in our dataset. Then, for each tree, we run the process we described above for a single decision tree, and we sum up all of the outputs. In this example, the prediction is $340K.

Gradient Boosted Decision Tree
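
In code, the ensemble prediction is just the baseline plus the sum of every tree's output. Here is a rough sketch that reuses the Node and predict from the sketch above; the function name and signature are ours, for illustration only.

// Regression prediction with the whole ensemble: the baseline
// (the mean of the training prices) plus the output of every tree.
fn predict_gbdt(baseline: f32, trees: &[Node], features: &[f32]) -> f32 {
    baseline + trees.iter().map(|tree| predict(tree, features)).sum::<f32>()
}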

To learn more about GBDTs, check out the Wikipedia article on gradient boosting.

Use Rayon to Parallelize

Now that we know a little about GBDTs, let's talk about how we made our code fast. The first thing we did was parallelize it. Rayon, a data-parallelism library for Rust, makes converting sequential operations into parallel ones extremely easy.

Matrix

The process of training trees takes in a matrix of training data which is n_rows by n_features.

Matrix with column-wise parallelization

To decide which feature to use at each node, we need to compute a score for each feature. We do this by iterating over each column in the matrix. The following is a sequential iteration over the columns.

dataset.columns().iter().map(|column| {
  // compute the score of branching on this feature
})

We can parallelize this code with Rayon. All we have to do is change the call to iter to par_iter.

dataset.columns().par_iter().map(|column| {
  // compute the score of branching on this feature
})

Rayon keeps a thread pool around and schedules items from your iterator to be processed in parallel. Parallelizing over the features works well when the number of features is larger than the number of logical cores on your computer. When the number of features is smaller than the number of logical cores, parallelizing over the features is not as efficient: some cores sit idle, so we are not using all of the compute power available to us. You can see this clearly in the image below. Cores 1 through 4 have work to do because they have features 1 through 4 assigned to them, but cores 5 through 8 are sitting idle.

Core utilization with column-wise parallelization

In this situation, we can parallelize over chunks of rows instead and make sure we have enough chunks so that each core has some work to do.

Matrix with row-wise parallelization

Each core now has some rows assigned to it, and no core is sitting idle.

Core utilization with row-wise parallelization

Distributing the work across rows is super easy with Rayon as well. We just use the par_chunks combinator!

dataset.rows().par_chunks(N).map(|chunk| {
  // process the chunk
});
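
If you want to play with these combinators outside of our codebase, here is a minimal runnable sketch of both, using a plain Vec<Vec<f32>> as a stand-in for our matrix type. The sums are just placeholders for the real split score, and you will need rayon as a dependency.

use rayon::prelude::*;

fn main() {
    // toy "matrix": 4 feature columns of 8 rows each
    let columns: Vec<Vec<f32>> = vec![vec![1.0; 8]; 4];

    // parallelize over columns (features)
    let scores: Vec<f32> = columns
        .par_iter()
        .map(|column| column.iter().sum()) // placeholder for the real split score
        .collect();

    // parallelize over chunks of rows instead
    let rows: Vec<f32> = vec![1.0; 1000];
    let chunk_sums: Vec<f32> = rows
        .par_chunks(128)
        .map(|chunk| chunk.iter().sum())
        .collect();

    println!("{:?} {:?}", scores, chunk_sums);
}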

These are just a couple of the combinators available in Rayon. There are a lot of other high-level combinators that make it easy to express complex parallel computations. Check out Rayon on GitHub to learn more.

Use cargo-flamegraph to find bottlenecks

Next, we used cargo-flamegraph to find where most of the time was being spent. cargo-flamegraph makes it easy to generate flamegraphs and integrates elegantly with cargo. You can install it with cargo install, then run cargo flamegraph to run your program and generate a flamegraph.

sudo apt install -y linux-perf
cargo install flamegraph
cargo flamegraph

Here is a simple example with a program that calls two subroutines, foo and bar.

fn main() {
  foo();
  bar();
}

When we run cargo flamegraph we get an output that looks like this.

Flamegraph

It contains a lot of extra functions that you have to sort through, but it boils down to something like this.

Simplified Flamegraph

The y-axis of the graph is the call stack, and the x-axis is duration. The bottom of the graph shows that the entire duration of the program was spent in the main function. Above that, you see that the main function’s time is broken up between calls to foo and bar: about two thirds of the time was spent in foo and its subroutines, and about one third was spent in bar and its subroutines.

In our code for training decision trees, the flamegraph showed one function where the majority of the time was spent. In this function, we maintain an array of the numbers 0..n that we call indexes, and at each iteration of training we rearrange it. Then, we access an array of the same length, called values, in the order given by the indexes array.

let indexes: Vec<usize> = (0..n).collect();

// rearrange indexes

for index in indexes {
  let mut value = values.get_mut(index);
  // mutate the value
}

This results in accessing each item in the values array out of order. We will refer back to this function throughout the rest of this post.
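
To make the access pattern concrete, here is a tiny self-contained version of that loop. The reverse() call is just a stand-in for however training actually rearranges indexes.

fn main() {
    // an array of values and an indexes array that starts as 0..n
    let mut values: Vec<f32> = vec![0.0; 8];
    let mut indexes: Vec<usize> = (0..values.len()).collect();

    // stand-in for the rearrangement that happens during training
    indexes.reverse();

    // values is now visited out of order: 7, 6, 5, ...
    for &index in &indexes {
        let value = values.get_mut(index).unwrap();
        *value += 1.0;
    }

    println!("{:?}", values);
}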

Use cargo-asm to identify suboptimal code generation

From the flamegraph, we knew which function was taking the majority of the time: the one we briefly described above. We started by looking at the assembly the compiler generated for it to see if there were any opportunities to make it faster. We did this with cargo-asm.

cargo install cargo-asm
cargo asm --rust path::to::a::function

Like cargo-flamegraph, cargo-asm is really easy to install and integrates nicely with cargo. You can install it with cargo install, and run it as a cargo subcommand.

Here is a simple example with a function that adds two numbers and multiplies the result by two.

pub fn sum_times_two(x: i32, y: i32) -> i32 {
  let sum = x + y;
  let sum_times_two = sum * 2;
  sum_times_two
}

When we run cargo asm, we get an output that looks like this. It shows the assembly instructions alongside the Rust code that generated them.

// pub fn sum_times_two(x: i32, y: i32) -> i32 {
// let sum = x + y;
add        edi, esi
// let sum_times_two = sum * 2;
lea        eax, [rdi, +, rdi]
//}
ret
//}

Note that because of all the optimizations the compiler does, there is often not a perfect correspondence between the Rust code and the assembly.

for index in indexes {
  let mut value = values.get_mut(index);
  // mutate the value
}

When we looked at the assembly for this loop, we were surprised to find an imul instruction, which is an integer multiplication.

imul     rax, qword, ptr, [r8, +, 16]

What is that doing in our code? We are just indexing into an array of f32s. f32s are 4 bytes each, so the compiler should be able to get the address of the ith item by multiplying i by 4. Multiplying by 4 is the same as shifting i left by 2, and a shift left is much faster than an integer multiplication.

for index in indexes {
  let mut value = values.get_mut(index);
  // mutate the value
}

Why would the compiler not be able to produce a shift left instruction? Well, the values array is a column in a matrix, and a matrix can be stored in either row-major or column-major order. This means that indexing into the column might require multiplying by the number of columns in the matrix, which is unknown at compile time. Since we were storing our matrix in column-major order, the multiplication could be eliminated, but we had to convince the compiler of this.

let values = values.as_slice();
// for index in indexes {
  // let mut value = values.get_mut(index);
  // mutate the value
// }

We did this by casting the values array to a slice. This convinced the compiler that the values array was contiguous, so it could access items with a shift left instead of an integer multiplication.
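
To see why the slice helps, here is a reduced illustration (not Tangram's actual matrix code): when the stride is only known at run time, the index computation needs a multiply, while indexing a plain &[f32] scales by a constant 4, which the compiler can emit as a shift or fold into the addressing mode.

// Reading element i of a column when the layout forces a run-time stride:
// the compiler has to emit an integer multiply.
pub fn get_with_stride(data: &[f32], stride: usize, i: usize) -> f32 {
    data[i * stride]
}

// Reading element i of a contiguous column: the address is just base + i * 4,
// so no general-purpose multiply is needed.
pub fn get_contiguous(data: &[f32], i: usize) -> f32 {
    data[i]
}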

Use intrinsics to optimize for specific CPUs

Next, we used compiler intrinsics to optimize for specific CPUs. Intrinsics are special functions that hint to the compiler to generate specific assembly code.

for index in indexes {
  let mut value = values.get_mut(index);
  // mutate the value
}

Values array out of order access

Remember how we noticed that this code results in accessing the values array out of order? This is really bad for cache performance, because CPUs prefetch memory on the assumption that you are going to access it in order. If a value isn’t in cache, the CPU has to wait for it to be loaded from main memory, making your program slower.

Cache Hierarchy

However, we know which values we are going to access a few iterations of the loop in the future, because the order is given by the indexes array. So, 10 iterations in the future, we will be accessing values[indexes[current_index + 10]]. We can hint to x86_64 CPUs to prefetch those values into cache using the mm_prefetch intrinsic. We experimented with different values of the OFFSET until we got the best performance. If the OFFSET is too small, the CPU will still have to wait for the data; if the OFFSET is too large, data that the CPU needs might be evicted, and by the time the CPU gets to the iteration that needs the data, it will no longer be there. The best offset varies with your computer's hardware, so tuning it can be more of an art than a science.

for (i, &index) in indexes.iter().enumerate() {
  // prefetch the value we will be reading OFFSET iterations from now
  // (handling of the last OFFSET iterations elided)
  mm_prefetch(values.as_ptr().add(indexes[i + OFFSET]));
  let value = values[index];
  // do something with value
}
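
For reference, here is a minimal self-contained sketch of this pattern using the actual _mm_prefetch intrinsic from std::arch::x86_64 (the snippet above writes mm_prefetch as shorthand). The sum_permuted name, the summing body, and the OFFSET of 10 are placeholders for illustration, and the intrinsic's const-generic hint parameter follows the current std documentation.

#[cfg(target_arch = "x86_64")]
fn sum_permuted(values: &[f32], indexes: &[usize]) -> f32 {
    use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
    const OFFSET: usize = 10; // tune for your hardware
    let mut sum = 0.0;
    for (i, &index) in indexes.iter().enumerate() {
        // hint to the CPU to pull the value we will need OFFSET iterations
        // from now into cache
        if let Some(&future_index) = indexes.get(i + OFFSET) {
            unsafe {
                _mm_prefetch::<_MM_HINT_T0>(values.as_ptr().add(future_index) as *const i8);
            }
        }
        sum += values[index];
    }
    sum
}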

If you are interested in cache performance and writing performant code, check out Mike Acton's talk on Data-Oriented Design and Andrew Kelley's recent talk at Handmade Seattle.

Use unsafe code to eliminate unnecessary bounds checks

Next, we used a touch of unsafe to remove some unnecessary bounds checks.

// for index in indexes {
  let mut value = values.get_mut(index);
  // mutate the value
//}

Most of the time, the compiler can eliminate bounds checks when looping over values in an array. However, in this code, it has to check that index is within the bounds of the values array.

// for index in indexes {
  let value = unsafe { values.get_unchecked_mut(index) };
  // mutate the value
// }

But we said at the beginning that the indexes array is a permutation of 0..n. This means the bounds check is unnecessary, so we can replace get_mut with get_unchecked_mut. We have to use unsafe code here, because Rust provides no way to communicate to the compiler that the values in the indexes array are always in bounds for the values array.
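
Here is a small sketch of what that looks like with the invariant spelled out; the function name and the debug_assert are ours, standing in for the reasoning the compiler cannot do for us.

fn increment_permuted(values: &mut [f32], indexes: &[usize]) {
    // Invariant: indexes is a permutation of 0..values.len(), so every
    // index is in bounds and the check performed by get_mut is redundant.
    debug_assert!(indexes.iter().all(|&index| index < values.len()));
    for &index in indexes {
        // Safety: in bounds by the invariant above.
        let value = unsafe { values.get_unchecked_mut(index) };
        *value += 1.0;
    }
}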

Use unsafe code to parallelize non-overlapping memory access

Finally, we parallelized the following code:

for index in indexes {
  let mut value = values.get_mut(index);
  // mutate the value
}

But is it even possible to parallelize? At first glance, it seems the answer is no, because we are accessing the values array mutably in the body of the loop.

indexes.par_iter().for_each(|index| {
  let mut value = values.get_mut(index);
})

If we try it, the compiler gives us an error about overlapping mutable borrows. However, the indexes array is a permutation of 0..n, so we know that no two iterations ever access the same element of the values array.

We can parallelize our code using unsafe Rust, wrapping a pointer to the values in a struct and unsafely marking it as Send and Sync.

struct ValuesPtr(*mut [f32]);
unsafe impl Send for ValuesPtr {}
unsafe impl Sync for ValuesPtr {}
let values = ValuesPtr(values);
indexes.par_iter().for_each(|&index| unsafe {
  let value = (*values.0).get_unchecked_mut(index);
  // mutate the value
})
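
Here is a small self-contained sketch of that wrapper pattern, doubling each value through a permutation of indexes. The double_permuted function and the example data are ours; the Send and Sync impls and the raw-slice access are the essential parts, and they are only sound because no two iterations touch the same element.

use rayon::prelude::*;

// Wrapper that lets us share a raw pointer to the values across threads.
// Safety: callers must guarantee that no two iterations access the same
// element, which holds here because indexes is a permutation of 0..len.
struct ValuesPtr(*mut [f32]);
unsafe impl Send for ValuesPtr {}
unsafe impl Sync for ValuesPtr {}

fn double_permuted(values: &mut [f32], indexes: &[usize]) {
    let values = ValuesPtr(values as *mut [f32]);
    indexes.par_iter().for_each(|&index| unsafe {
        let value = (*values.0).get_unchecked_mut(index);
        *value *= 2.0;
    });
}

fn main() {
    let mut values = vec![1.0_f32, 2.0, 3.0, 4.0];
    let indexes = vec![3, 1, 0, 2];
    double_permuted(&mut values, &indexes);
    println!("{:?}", values); // [2.0, 4.0, 6.0, 8.0]
}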

So, going back to the code we started out with...

for index in indexes {
  let mut value = values.get_mut(index);
  // mutate the value
}

Finally, combining the four optimizations:

  1. Making sure that the values array is a contiguous slice.
  2. Prefetching values so they are in cache.
  3. Removing bounds checks because we know the indexes are always in bounds.
  4. Parallelizing over the indexes because we know they never overlap.

this is the code we get:

struct ValuesPtr(*mut [f32]);
unsafe impl Send for ValuesPtr {}
unsafe impl Sync for ValuesPtr {}
// 1. make sure the values are a contiguous, mutable slice
let values = ValuesPtr(values.as_mut_slice());
indexes.par_iter().enumerate().for_each(|(i, &index)| unsafe {
  // 2. prefetch the value we will need OFFSET iterations from now
  mm_prefetch((*values.0).as_ptr().add(indexes[i + OFFSET]));
  // 3. skip the bounds check; 4. mutate in parallel through the raw slice
  let value = (*values.0).get_unchecked_mut(index);
  // mutate the value
})

Here are our benchmarks on training time comparing Tangram's Gradient Boosted Decision Tree Library to LightGBM, XGBoost, CatBoost, and sklearn.

Training Time Benchmarks

To see all the benchmarks, head over to https://tangram.dev/benchmarks.

If you are interested in reading the code or giving us a star, the project is available on GitHub.
