Writing the Fastest GBDT Library in Rust by Isabella Tromba
In this talk, I will share my experience optimizing a Rust implementation of the Gradient Boosted Decision Tree machine learning algorithm. With code snippets, stack traces, and benchmarks, we'll explore how rayon, perf, cargo-asm, compiler intrinsics, and unsafe Rust were used to write a GBDT library that trains faster than similar libraries written in C/C++.
Hi, I'm Isabella. Today we're going to talk about how we wrote the fastest gradient boosted decision tree library in Rust. First, a bit about me: I'm a founder at Tangram, where we build open source tools that make it easy for programmers to train, deploy, and monitor machine learning models. Tangram is written entirely in Rust, from the core machine learning algorithms to the back end and front end of the web application. You can check it out on GitHub or on our website at the links here.

So what are gradient boosted decision trees? Let's say you want to predict the price of a house based on features like the number of bedrooms, bathrooms, and square footage. To make a prediction with a decision tree, you start at the top, and at each branch you ask how one of the features compares to a threshold. If the value is less than or equal to the threshold, you go to the left child. If the value is greater, you go to the right child. When you reach a leaf, you have the prediction.

Here's an example. We have a house with three bedrooms, three bathrooms, and 2,500 square feet. Let's see what price our decision tree produces. Start at the top. The number of bedrooms is three, which is less than or equal to three, so we go left. The square footage is 2,500, which is greater than 2,400, so we go right, and we arrive at the prediction, which is 512,000.

A single decision tree isn't very good at making predictions on its own, so we train a bunch of trees one at a time, where each tree predicts the error in the sum of the outputs of the trees before it. This is called gradient boosting over decision trees. In this example, the prediction is 340,000.

Now let's talk about how we made our implementation fast. The first thing we did was parallelize our code, and Rayon makes this really easy. The process of training trees takes in a matrix of training data, which is n rows by n features. To decide which feature to use in each node, we need to compute a score for each feature.
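Before moving on to performance, the tree walk described earlier can be sketched in Rust. This is a minimal illustration, not Tangram's actual data structures; the two leaf values other than 512,000 are made up for the sketch.

```rust
// A toy decision tree: branches compare one feature to a threshold,
// leaves hold the prediction.
enum Node {
    Branch {
        feature: usize,
        threshold: f32,
        left: Box<Node>,
        right: Box<Node>,
    },
    Leaf {
        value: f32,
    },
}

// Walk from the root: go left if the feature is <= the threshold,
// right otherwise, until we reach a leaf.
fn predict(node: &Node, features: &[f32]) -> f32 {
    match node {
        Node::Leaf { value } => *value,
        Node::Branch { feature, threshold, left, right } => {
            if features[*feature] <= *threshold {
                predict(left, features)
            } else {
                predict(right, features)
            }
        }
    }
}

// The house-price example from the talk; the 400,000 and 600,000
// leaves are invented placeholders for the branches we never visit.
fn example_tree() -> Node {
    Node::Branch {
        feature: 0, // number of bedrooms
        threshold: 3.0,
        left: Box::new(Node::Branch {
            feature: 2, // square footage
            threshold: 2400.0,
            left: Box::new(Node::Leaf { value: 400_000.0 }),
            right: Box::new(Node::Leaf { value: 512_000.0 }),
        }),
        right: Box::new(Node::Leaf { value: 600_000.0 }),
    }
}
```

With three bedrooms and 2,500 square feet, `predict(&example_tree(), &[3.0, 3.0, 2500.0])` goes left, then right, and returns 512,000, matching the walk-through above.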
We can compute that score for each feature in parallel with Rayon. It's as easy as changing the call to iter to par_iter. Rayon will keep a thread pool around and schedule items from your iterator to be processed in parallel. This works well when the number of features is larger than the number of cores on your computer.

When the number of features is smaller than the number of cores on your computer, parallelizing over the features is not as efficient. This is because some of the cores will be sitting idle, so we will not be using all the compute power available to us. Instead, we can parallelize over chunks of rows and make sure we have enough chunks so that each core has some work to do. Each core now has some rows assigned to it and no core is sitting idle. Distributing the work across rows is super easy with Rayon as well: we can just use the combinator par_chunks. Rayon has a lot of other high-level combinators that make it easy to express complex parallel computations.

Next, we used cargo flamegraph to find where most of the time was being spent. cargo flamegraph makes it easy to generate flame graphs and integrates elegantly with Cargo. You can install it with cargo install and then run cargo flamegraph to run your program and generate a flame graph. Here's a simple example with a program that calls two subroutines, foo and bar. When we run cargo flamegraph, we get an output that looks like this. It contains a lot of extra functions that you have to sort through, but it boils down to something like this. The y axis of the graph is the call stack and the x axis is duration. The bottom of the graph shows that the entire duration of the program was spent in the main function. Above that, you see that the main function's time is broken up between calls to foo and bar, and that about two thirds of the time was spent in foo and its subroutines versus about one third of the time spent in bar and its subroutines.
In our code for training decision trees, the flame graph showed one function where the majority of the time was spent. It boiled down to something like this: we maintain an array of the numbers zero to n that we call indexes, and at each iteration of training we rearrange it. Then we access an array of the same length called values, but in the order given by the indexes array. This results in accessing each item in the values array out of order.

From the flame graph, we knew which function was taking the majority of the time, so we looked at the generated assembly code to see if there were any opportunities to make it faster. We did this with cargo-asm. Like cargo flamegraph, cargo-asm is really easy to install and integrates nicely with Cargo. You can install it with cargo install and run it as a Cargo subcommand. Here is a simple example with a function that adds two numbers and multiplies the result by two. When we run cargo asm, we get an output that looks like this: it shows the assembly instructions alongside the Rust code that generated them. Note that due to all the optimizations the compiler does, there's often not a perfect correlation from the Rust code to the assembly.

When we looked at the assembly for this loop, we were surprised to find a mul instruction, which is an integer multiplication. What is that doing in our code? We're just indexing into an array of f32s, and f32s are four bytes each, so the compiler should be able to get the address of the ith item by multiplying i by four, and it can do that by shifting i left by two, which is much faster than integer multiplication. Well, the values array is a column in a matrix, and a matrix can be stored in either row-major or column-major order. This means that indexing into the column might require multiplying by the number of columns in the matrix, which is unknown at compile time.
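The hot loop the flame graph pointed at boiled down to something like this sketch. The names are illustrative and a running sum stands in for the real per-item work.

```rust
// indexes is a permutation of 0..n; values has the same length n.
fn gather(indexes: &[usize], values: &[f32]) -> f32 {
    let mut sum = 0.0;
    for &index in indexes {
        // Because indexes has been rearranged, this access jumps
        // around the values array out of order.
        sum += values[index];
    }
    sum
}
```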
But since we're storing our matrix in column-major order, we could eliminate the multiplication; we just had to convince the compiler of this. We did that by casting the values array to a slice. This convinced the compiler that the values array was contiguous, so it could access items using a shift-left instruction instead of an integer multiplication.

Next, we used compiler intrinsics to optimize for specific CPUs. Intrinsics are special functions that hint to the compiler to generate specific assembly code. Remember how we noticed that this code results in accessing the values array out of order? This is really bad for cache performance, because CPUs assume you're going to access memory in order. If a value isn't in cache, the CPU has to wait until it's loaded from main memory, making your program slower. However, we know which values we're going to be accessing a few iterations of the loop in the future, so we can hint to x86-64 CPUs to prefetch those values into cache using the _mm_prefetch intrinsic. We experimented with different values of the offset until we got the best performance.

Next, we used a touch of unsafe to remove some unnecessary bounds checks. Most of the time, the compiler can eliminate bounds checks when looping over values in an array. However, in this code it has to check that the index is within the bounds of the values array. But as we said in the beginning, the indexes array is just a permutation of the values zero to n, which means the bounds checks are unnecessary. We can fix this by replacing get_mut with get_unchecked_mut. We have to use unsafe code here because Rust provides no way to communicate to the compiler that the values in the indexes array are always in bounds of the values array.

Finally, we parallelized that section of code. But is it even possible to parallelize? At first glance, it seems like the answer is no, because we are accessing the values array mutably in the body of the loop.
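The prefetching and bounds-check removal can be sketched together like this. The prefetch distance of 8 is an assumption (the talk tuned the offset by experiment), and this read-only version uses get_unchecked where the talk's mutating loop uses get_unchecked_mut.

```rust
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};

// indexes must be a permutation of 0..values.len().
fn gather_optimized(indexes: &[usize], values: &[f32]) -> f32 {
    // Assumed prefetch distance; the real value was found by tuning.
    const PREFETCH_OFFSET: usize = 8;
    let mut sum = 0.0;
    for (i, &index) in indexes.iter().enumerate() {
        #[cfg(target_arch = "x86_64")]
        if let Some(&future) = indexes.get(i + PREFETCH_OFFSET) {
            // Hint the CPU to pull a value we'll need soon into cache.
            unsafe {
                _mm_prefetch::<_MM_HINT_T0>(values.as_ptr().add(future) as *const i8);
            }
        }
        // SAFETY: indexes is a permutation of 0..values.len(), so every
        // index is in bounds and the bounds check can be skipped.
        sum += unsafe { *values.get_unchecked(index) };
    }
    sum
}
```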
If we try it, the compiler will give us an error indicating overlapping borrows. However, the indexes array is a permutation of the values zero to n, so we know that the accesses to the values array never overlap. We can parallelize our code using unsafe Rust, wrapping a pointer to the values in a struct and unsafely marking it as Send and Sync.

So, going back to the code we started out with, we combine the four optimizations together: making sure that the values array is a contiguous slice, prefetching values so they're in cache, removing bounds checks because we know that the indexes are always in bounds, and parallelizing over the indexes because we know they never overlap. This is the code we get, and the results are great: Tangram's gradient boosted decision tree library is faster than leading open source alternatives.

Thank you so much for listening. If you're interested in learning more, please check out Tangram on GitHub at github.com/tangramdotdev/tangram, and if you like the project, please give it a star. Thank you.
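For reference, the pointer-wrapper trick from that final optimization can be sketched like this. To keep the sketch self-contained it uses std scoped threads in place of Rayon, and the struct name and the increment work are illustrative.

```rust
// Wrap a raw pointer so it can cross thread boundaries. Raw pointers
// are not Send/Sync by default, so we assert it ourselves.
struct SharedPtr(*mut f32);
// SAFETY: callers must guarantee no two threads touch the same element.
unsafe impl Send for SharedPtr {}
unsafe impl Sync for SharedPtr {}

// indexes must be a permutation of 0..values.len(), so the two halves
// never write to the same slot.
fn apply_parallel(indexes: &[usize], values: &mut [f32]) {
    let ptr = SharedPtr(values.as_mut_ptr());
    let (left, right) = indexes.split_at(indexes.len() / 2);
    std::thread::scope(|scope| {
        for chunk in [left, right] {
            let ptr = &ptr;
            scope.spawn(move || {
                for &index in chunk {
                    // SAFETY: the chunks are disjoint subsets of a
                    // permutation, so these writes never overlap.
                    unsafe { *ptr.0.add(index) += 1.0 };
                }
            });
        }
    });
}
```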