As Elixir's Machine Learning (ML) ecosystem grows, many Elixir enthusiasts who wish to adopt the new machine learning libraries in their projects are stuck at a crossroads: they want to move away from their existing ML stack (typically Python) but don't have a clear path for doing so. I would like to take some time to talk about WHY I believe now is a good time to start porting machine learning code over to Elixir, and HOW I went about doing just that for two libraries I wrote: EXGBoost (from Python's XGBoost) and Mockingjay (from Python's Hummingbird).
Why is Python not Sufficient?
There's a common saying about programming languages that no language is perfect, but that different languages are suited for different jobs. Languages such as C, Rust, and now even Zig are known for targeting systems development, while languages such as C++, C#, and Java are more commonly used for application development, and then of course there are the web languages such as JavaScript/TypeScript, PHP, Ruby (on Rails), and more. There are gradations to these rules, but more often than not there are good reasons that languages tend to exist within the confines of particular use cases.
Languages such as Elixir and Go tend to be used in large distributed systems because they place an emphasis on great support for common concurrency patterns, which can come at the cost of supporting other domains. Go, for example, has barely any support for machine learning libraries, but it is also not trying to cater to that domain. For a long time, the same could have been said about Elixir, but over the past two or so years there has been a massive, concerted push from the Elixir community not only to support machine learning, but to push the envelope by maintaining state-of-the-art libraries that are beginning to compete with the dominant machine learning language - namely Python.
Python has long been the gold standard in the realm of machine learning. The breadth of libraries and the low barrier to entry make Python a great language to work with, but it does create a bit of a bottleneck. Any application that wishes to integrate machine learning has historically had only a couple of options: add a Python component, or reach directly into the underlying native libraries that power much of the Python ecosystem. Despite all the good parts of Python I mentioned before, speed and support for concurrency are not on that list. Elixir-Nx is striving to give another option - one that can take advantage of the native distributed support that Elixir and the BEAM VM have to offer. Nx's `Nx.Serving` construct, for example, is a drop-in solution for serving distributed machine-learning models.
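To give a feel for it, here is a minimal sketch of serving a function with `Nx.Serving`. The doubling function and the `MyApp.Serving` name are placeholders made up for illustration; a real serving would typically wrap an Axon or Bumblebee model:

```elixir
# Build a serving around any defn-compatible function.
serving = Nx.Serving.new(fn opts -> Nx.Defn.jit(&Nx.multiply(&1, 2), opts) end)

# Run it directly on a batch...
batch = Nx.Batch.stack([Nx.tensor([1, 2, 3])])
Nx.Serving.run(serving, batch)

# ...or add it to a supervision tree and call it by name, even from another node:
#   {Nx.Serving, serving: serving, name: MyApp.Serving, batch_size: 16}
#   Nx.Serving.batched_run(MyApp.Serving, batch)
```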
How to Proceed
Sean Moriarity, the co-creator of Nx, creator of Axon, and author of Machine Learning in Elixir, has talked many times about how the initial creation of Nx and Axon involved hours upon hours of reading source code from reference implementations in Python and C++, namely the TensorFlow source code. While I was writing EXGBoost and Mockingjay, much of my time, especially towards the beginning, was likewise spent referencing the Python and C++ implementations of the original libraries. This built a great fundamental understanding of those libraries and taught me how to recognize patterns in Python and C++ and find the Elixir patterns that could express the same ideas. This skill is invaluable, and the better I got at it the faster I could write. Below is a summary of, and key takeaways from, my process of porting Python / PyTorch to Elixir / Nx.
Workflow Overview
Before I get to the examples from the code bases, I would like to briefly explain the high-level cyclical workflow I established while working on this effort, and what I would recommend to anyone pursuing a similar endeavor.
Understand the System at a Macro Level
Much like the common reading-comprehension strategy of reading through an entire document once to get a high-level understanding and then doing subsequent, shorter reads to gain more in-depth understanding with the added context of the whole piece, you can do the same when reading code. My first step was to follow the logical flow from the call to `hummingbird.ml.convert` through to the final result. You can use tools such as function tracers and call-graph generators to accelerate this part of the process, or trace manually depending on the extent of the codebase. In my case I felt it was manageable to trace by hand.
Read the Documentation
Once you have a general understanding of the flow and process of the original system, you can start referring to the documentation for additional context. In my case, this led me to the academic paper Taming Model Serving Complexity, Performance and Cost: A Compilation to Tensor Computations Approach, which was the underlying groundwork and basis for Hummingbird's implementation. I could write a whole other blog post about the process of transcribing algorithms from academic papers and pseudocode, but for now just know that these are some of the most important resources you can refer to while re-implementing or porting a piece of source code.
Read the Source Code in Detail
This is the point at which you want to disambiguate the higher-level ideas from the first step and gain a fine, high-resolution understanding of what is actually happening. There might even be points at which you need to deconflict the source code with its documentation and/or the paper it references. In those cases, the source code almost always wins, and if not, then you likely have a bug report you can file. If you see things you don't fully understand, you don't necessarily need to address them here, but you should make note of them and keep them in mind while working, in case new details help resolve them.
Implement the New Code
At this point, you should feel comfortable enough to start implementing the code. I found this to be a very iterative process: I would think I had a grasp on something, start implementing it, realize I did not understand it as well as I had thought, and work my way back through the previous steps.
Example
💡
In case you would like to follow along, the Python code I will be referencing is the Microsoft Hummingbird source code (specifically its implementation of decision tree compilation), and the Elixir code is from the Mockingjay source code.
Class vs. Behaviour
From the reading and comprehension I did of the Hummingbird code base, I realized fairly early on that my library was going to have some key differences. One of the main reasons was that Hummingbird was built as a retroactive library that needed to cater to APIs that already existed throughout the Python ecosystem. They chose to only support converting decision trees that conform to the SKLearn API. I, conversely, chose to write Mockingjay in such a way that it would be incumbent upon the authors of decision tree libraries to implement a protocol in order to interface with Mockingjay's `convert` function. This difference meant that I could establish a `Mockingjay.Tree` data structure to use throughout my library, rather than having to reconstruct tree features from various other APIs as is done in Hummingbird.
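To make that concrete, here is a rough, hypothetical sketch of what such a protocol could look like. The protocol and function names below are illustrative only, not Mockingjay's actual API:

```elixir
defprotocol MyConverter.DecisionTree do
  @doc "Returns the fitted trees as a list of Mockingjay.Tree structs"
  def trees(model)

  @doc "Returns the number of classes the model predicts"
  def num_classes(model)

  @doc "Returns the number of input features the model expects"
  def num_features(model)
end

# A tree library (EXGBoost, for example) would implement the protocol for its
# own model struct, so the converter only ever works with Mockingjay.Tree.
```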
Next, Hummingbird approaches its pipeline in a very object-oriented manner, as makes sense in Python. Here we are focusing on the implementation of the three decision tree conversion strategies: GEMM, Tree Traversal, and Perfect Tree Traversal. Hummingbird implements the following base classes for tree conversions as well as the PyTorch networks.
💡
Since they inherit from `torch.nn.Module`, they must also implement the `forward` method.
```python
class AbstracTreeImpl(PhysicalOperator):
    """
    Abstract class definig the basic structure for tree-base models.
    """

    def __init__(self, logical_operator, **kwargs):
        super().__init__(logical_operator, **kwargs)

    @abstractmethod
    def aggregation(self, x):
        """
        Method defining the aggregation operation to execute after the model is evaluated.

        Args:
            x: An input tensor
        Returns:
            The tensor result of the aggregation
        """
        pass
```
```python
class AbstractPyTorchTreeImpl(AbstracTreeImpl, torch.nn.Module):
    """
    Abstract class definig the basic structure for tree-base models implemented in PyTorch.
    """

    def __init__(
        self, logical_operator, tree_parameters, n_features, classes, n_classes, decision_cond="<=", extra_config={}, **kwargs
    ):
        """
        Args:
            tree_parameters: The parameters defining the tree structure
            n_features: The number of features input to the model
            classes: The classes used for classification. None if implementing a regression model
            n_classes: The total number of used classes
            decision_cond: The condition of the decision nodes in the x <cond> threshold order. Default '<='. Values can be <=, <, >=, >
        """
        super(AbstractPyTorchTreeImpl, self).__init__(logical_operator, **kwargs)
```
They then proceed to inherit from these base classes, with different classes for each of the three decision tree strategies as well as their gradient-boosted counterparts, leaving them with three classes per strategy (one base class, one for ensemble implementations, and one for plain implementations) and nine classes in total.
I chose to approach this using a behaviour:
```elixir
defmodule Mockingjay.Strategy do
  @moduledoc false
  @type t :: Nx.Container.t()

  @callback init(data :: any(), opts :: Keyword.t()) :: term()
  @callback forward(x :: Nx.Container.t(), term()) :: Nx.Tensor.t()
  ...
end
```
`init` performs setup depending on the strategy and returns the parameters that will need to be passed to `forward` later on. This allows for a very simple top-level API. The whole top-level `mockingjay.ex` file can fit here:
```elixir
def convert(data, opts \\ []) do
  {strategy, opts} = Keyword.pop(opts, :strategy, :auto)

  strategy =
    case strategy do
      :gemm ->
        Mockingjay.Strategies.GEMM

      :tree_traversal ->
        Mockingjay.Strategies.TreeTraversal

      :perfect_tree_traversal ->
        Mockingjay.Strategies.PerfectTreeTraversal

      :auto ->
        Mockingjay.Strategy.get_strategy(data, opts)

      _ ->
        raise ArgumentError,
              "strategy must be one of :gemm, :tree_traversal, :perfect_tree_traversal, or :auto"
    end

  {post_transform, opts} = Keyword.pop(opts, :post_transform, nil)
  state = strategy.init(data, opts)

  fn data ->
    result = strategy.forward(data, state)
    {_, n_trees, n_classes} = Nx.shape(result)

    result
    |> aggregate(n_trees, n_classes)
    |> post_transform(post_transform, n_classes)
  end
end
```
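Calling it looks roughly like this (a hedged sketch: `model` stands in for any struct from a library that implements the protocol, such as a trained EXGBoost booster):

```elixir
# convert/2 returns a plain function that runs the compiled prediction pipeline.
predict = Mockingjay.convert(model, strategy: :gemm)

x = Nx.tensor([[1.0, 2.0, 3.0]])
predict.(x)
```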
As you can see, the use of a behaviour here allows a strategy-agnostic approach to generating a prediction pipeline. In the object-oriented implementation, each class implements `init`, `forward`, `aggregate`, and `post_transform`. We get the same result from a functional pipeline approach, where each step generates the information needed as input for the next step. So, instead of storing intermediate results as object properties or values in an object's `__dict__`, we just pass them along the pipeline. I would argue this creates a much simpler and easier-to-follow implementation (but I am also quite biased).
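For completeness, here is a minimal, hypothetical sketch of what a module implementing the `Mockingjay.Strategy` behaviour might look like. The module name and its trivial logic are made up purely to show the shape of the contract, and it may not cover every callback the real behaviour defines:

```elixir
defmodule MyApp.ConstantStrategy do
  @behaviour Mockingjay.Strategy

  # init/2 does any one-time setup and returns the state forward/2 will need.
  @impl true
  def init(_data, opts) do
    Keyword.validate!(opts, n_trees: 1, n_classes: 2)
  end

  # forward/2 produces a {batch, n_trees, n_classes} result for the pipeline.
  @impl true
  def forward(x, state) do
    batch_size = Nx.axis_size(x, 0)
    Nx.broadcast(0.0, {batch_size, state[:n_trees], state[:n_classes]})
  end
end
```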
PyTorch to Nx
For these examples, we will be looking at porting the implementations of the `forward` function for the three conversion strategies from Python to Nx.
GEMM
Let's look at the `forward` function implementation for GEMM, the first of the three conversion strategies. In Hummingbird, the `forward` step for each strategy is implemented in that strategy's base class. So given three GEMM classes with the signatures `GEMMTreeImpl(AbstractPyTorchTreeImpl)`, `GEMMDecisionTreeImpl(GEMMTreeImpl)`, and `GEMMGBDTImpl(GEMMTreeImpl)`, the `forward` function is defined in the `GEMMTreeImpl` class, since both ensemble and non-ensemble decision tree models share the same forward step.
```python
def forward(self, x):
    x = x.t()
    x = self.decision_cond(torch.mm(self.weight_1, x), self.bias_1)
    x = x.view(self.n_trees, self.hidden_one_size, -1)
    x = x.float()

    x = torch.matmul(self.weight_2, x)

    x = x.view(self.n_trees * self.hidden_two_size, -1) == self.bias_2
    x = x.view(self.n_trees, self.hidden_two_size, -1)
    if self.tree_op_precision_dtype == "float32":
        x = x.float()
    else:
        x = x.double()

    x = torch.matmul(self.weight_3, x)
    x = x.view(self.n_trees, self.hidden_three_size, -1)
```
Now, here is the Nx implementation:
```elixir
@impl true
deftransform forward(x, {arg, opts}) do
  opts =
    Keyword.validate!(opts, [
      :condition,
      :n_trees,
      :n_classes,
      :max_decision_nodes,
      :max_leaf_nodes,
      :n_weak_learner_classes,
      :custom_forward
    ])

  _forward(x, arg, opts)
end
```
```elixir
defnp _forward(x, arg, opts \\ []) do
  %{mat_A: mat_A, mat_B: mat_B, mat_C: mat_C, mat_D: mat_D, mat_E: mat_E} = arg
  condition = opts[:condition]
  n_trees = opts[:n_trees]
  n_classes = opts[:n_classes]
  max_decision_nodes = opts[:max_decision_nodes]
  max_leaf_nodes = opts[:max_leaf_nodes]
  n_weak_learner_classes = opts[:n_weak_learner_classes]

  mat_A
  |> Nx.dot([1], x, [1])
  |> condition.(mat_B)
  |> Nx.reshape({n_trees, max_decision_nodes, :auto})
  |> then(&Nx.dot(mat_C, [2], [0], &1, [1], [0]))
  |> Nx.reshape({n_trees * max_leaf_nodes, :auto})
  |> Nx.equal(mat_D)
  |> Nx.reshape({n_trees, max_leaf_nodes, :auto})
  |> then(&Nx.dot(mat_E, [2], [0], &1, [1], [0]))
  |> Nx.reshape({n_trees, n_weak_learner_classes, :auto})
  |> Nx.transpose()
  |> Nx.reshape({:auto, n_trees, n_classes})
end
```
Do not be distracted by the length of this snippet: many of the lines are taken up by validating and unpacking arguments. Let's look at a stripped-down version without that:
```elixir
@impl true
deftransform forward(x, {arg, opts}) do
  _forward(x, arg, opts)
end

defnp _forward(x, arg, opts \\ []) do
  mat_A
  |> Nx.dot([1], x, [1])
  |> condition.(mat_B)
  |> Nx.reshape({n_trees, max_decision_nodes, :auto})
  |> then(&Nx.dot(mat_C, [2], [0], &1, [1], [0]))
  |> Nx.reshape({n_trees * max_leaf_nodes, :auto})
  |> Nx.equal(mat_D)
  |> Nx.reshape({n_trees, max_leaf_nodes, :auto})
  |> then(&Nx.dot(mat_E, [2], [0], &1, [1], [0]))
  |> Nx.reshape({n_trees, n_weak_learner_classes, :auto})
  |> Nx.transpose()
  |> Nx.reshape({:auto, n_trees, n_classes})
end
```
Let's take a look at some obvious differences:

- The Nx code does not have to transpose in the first step, since `Nx.dot/4` allows you to specify the contracting axes (see the short sketch after this list).
- You can use `Nx.dot/6` to get the same behavior as `torch.matmul`
  - `torch.matmul` does a lot of wizardry with broadcasting to make this instance work
- We use functions such as `Nx.equal/2` to fit into the pipeline, rather than the `==` operator (which would work outside of a pipeline)
- `torch`'s `view` is equivalent to `Nx.reshape`
- Nx uses the `:auto` atom where `torch` uses `-1` to indicate that the size of an axis should be inferred
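As a small, standalone illustration of the first and last points (the shapes here are arbitrary, chosen just for the example):

```elixir
a = Nx.iota({2, 3})  # e.g. weights of shape {n_out, n_features}
x = Nx.iota({4, 3})  # e.g. a batch of inputs of shape {batch, n_features}

# Contract axis 1 of both tensors: no explicit transpose of x is needed,
# unlike torch.mm(a, x.t()). The result has shape {2, 4}.
Nx.dot(a, [1], x, [1])

# :auto infers the remaining axis size, like view(-1, 2) in PyTorch.
Nx.reshape(Nx.iota({2, 3}), {:auto, 2})
```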
Outside of these differences, the code translates fairly easily. Let's take a look at a slightly more complex instance.
Tree Traversal
Here is the Python implementation:
```python
def _expand_indexes(self, batch_size):
    indexes = self.nodes_offset
    indexes = indexes.expand(batch_size, self.num_trees)
    return indexes.reshape(-1)

def forward(self, x):
    batch_size = x.size()[0]
    indexes = self.nodes_offset
    indexes = indexes.expand(batch_size, self.num_trees).reshape(-1)

    for _ in range(self.max_tree_depth):
        tree_nodes = indexes
        feature_nodes = torch.index_select(self.features, 0, tree_nodes).view(-1, self.num_trees)
        feature_values = torch.gather(x, 1, feature_nodes)

        thresholds = torch.index_select(self.thresholds, 0, indexes).view(-1, self.num_trees)
        lefts = torch.index_select(self.lefts, 0, indexes).view(-1, self.num_trees)
        rights = torch.index_select(self.rights, 0, indexes).view(-1, self.num_trees)

        indexes = torch.where(self.decision_cond(feature_values, thresholds), lefts, rights).long()
        indexes = indexes + self.nodes_offset
        indexes = indexes.view(-1)

    output = torch.index_select(self.values, 0, indexes).view(-1, self.num_trees, self.n_classes)
```
And here is the Nx implementation:
```elixir
defn _forward(x, features, lefts, rights, thresholds, nodes_offset, values, opts \\ []) do
  max_tree_depth = opts[:max_tree_depth]
  num_trees = opts[:num_trees]
  n_classes = opts[:n_classes]
  condition = opts[:condition]
  unroll = opts[:unroll]

  batch_size = Nx.axis_size(x, 0)

  indices =
    nodes_offset
    |> Nx.broadcast({batch_size, num_trees})
    |> Nx.reshape({:auto})

  {indices, _} =
    while {tree_nodes = indices, {features, lefts, rights, thresholds, nodes_offset, x}},
          _ <- 1..max_tree_depth,
          unroll: unroll do
      feature_nodes = Nx.take(features, tree_nodes) |> Nx.reshape({:auto, num_trees})
      feature_values = Nx.take_along_axis(x, feature_nodes, axis: 1)
      local_thresholds = Nx.take(thresholds, tree_nodes) |> Nx.reshape({:auto, num_trees})
      local_lefts = Nx.take(lefts, tree_nodes) |> Nx.reshape({:auto, num_trees})
      local_rights = Nx.take(rights, tree_nodes) |> Nx.reshape({:auto, num_trees})

      result =
        Nx.select(
          condition.(feature_values, local_thresholds),
          local_lefts,
          local_rights
        )
        |> Nx.add(nodes_offset)
        |> Nx.reshape({:auto})

      {result, {features, lefts, rights, thresholds, nodes_offset, x}}
    end

  values
  |> Nx.take(indices)
  |> Nx.reshape({:auto, num_trees, n_classes})
end
```
Here there are some more striking differences, most notably the use of Nx's `while` expression compared to the `for` loop in Python. We use `while` in this case since it achieves the same purpose as the Python `for` loop and is supported by Nx within a `defn`. Otherwise, we might have to perform some of the calculations within a `deftransform`, as we will see in the next example. Another obvious difference is that in the Nx implementation we have to pass the required variables along through these operations, whereas Python can use stored class attributes.
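If you have not used `while` in `defn` before, here is a tiny, self-contained sketch of the idea (a made-up example, not code from Mockingjay): the loop state is threaded through each iteration, much like a fold.

```elixir
defmodule WhileExample do
  import Nx.Defn

  # Sums the integers 1..n by threading {i, total} through the loop.
  defn sum_to(n) do
    {_i, total} =
      while {i = 1, total = 0}, i <= n do
        {i + 1, total + i}
      end

    total
  end
end

WhileExample.sum_to(Nx.tensor(5))  # => tensor with value 15
```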
Still, the conversion is quite straightforward. I hope you are beginning to see that this is not an impossible effort, and can be accomplished given you have a firm understanding of the source material.
Perfect Tree Traversal
Lastly, let's look at the third conversion strategy. This one is slightly more complex still, but hopefully seeing it will help you in your own case:
```python
def forward(self, x):
    prev_indices = (self.decision_cond(torch.index_select(x, 1, self.root_nodes), self.root_biases)).long()
    prev_indices = prev_indices + self.tree_indices
    prev_indices = prev_indices.view(-1)

    factor = 2
    for nodes, biases in zip(self.nodes, self.biases):
        gather_indices = torch.index_select(nodes, 0, prev_indices).view(-1, self.num_trees)
        features = torch.gather(x, 1, gather_indices).view(-1)
        prev_indices = (
            factor * prev_indices + self.decision_cond(features, torch.index_select(biases, 0, prev_indices)).long()
        )

    output = torch.index_select(self.leaf_nodes, 0, prev_indices).view(-1, self.num_trees, self.n_classes)
```
And the Elixir implementation:
```elixir
defnp _forward(
        x,
        root_features,
        root_thresholds,
        features,
        thresholds,
        values,
        indices,
        opts \\ []
      ) do
  prev_indices =
    x
    |> Nx.take(root_features, axis: 1)
    |> opts[:condition].(root_thresholds)
    |> Nx.add(indices)
    |> Nx.reshape({:auto})
    |> forward_reduce_features(x, features, thresholds, opts)

  Nx.take(values, prev_indices)
  |> Nx.reshape({:auto, opts[:num_trees], opts[:n_classes]})
end
```
```elixir
deftransformp forward_reduce_features(prev_indices, x, features, thresholds, opts \\ []) do
  Enum.zip_reduce(
    Tuple.to_list(features),
    Tuple.to_list(thresholds),
    prev_indices,
    fn nodes, biases, acc ->
      gather_indices = nodes |> Nx.take(acc) |> Nx.reshape({:auto, opts[:num_trees]})
      features = Nx.take_along_axis(x, gather_indices, axis: 1) |> Nx.reshape({:auto})

      acc
      |> Nx.multiply(@factor)
      |> Nx.add(opts[:condition].(features, Nx.take(biases, acc)))
    end
  )
end
```
You can see that in this case we have a function defined in a `deftransform` within our `forward` pipeline. Why is this so? Well, when writing code within a `defn`, you give up the default Elixir `Kernel` in favor of `Nx.Defn.Kernel`. If you want full access to the normal Elixir modules, you need to drop into a `deftransform`. We needed `Enum.zip_reduce` in this instance (rather than Nx's `while` like before) since the `features` and `thresholds` collections are not of uniform shape: each entry corresponds to a given depth of a binary tree, so they hold tensors with sizes `[1, 2, 4, 8, ...]`. This is an optimization compared to the normal Tree Traversal strategy, but it required a somewhat different approach from the Python implementation, which took advantage of `torch.nn.ParameterList` to build out the same lists. You might also notice the calls to `Tuple.to_list` inside `forward_reduce_features`. This was required since `features` and `thresholds` need to be stored in `Nx.Container`s when passed into the `deftransform`, and `Tuple` implements the `Nx.Container` protocol while lists do not. Still, given that knowledge of the intricacies of `defn` and `deftransform`, the final ported solution is very similar to the reference solution.
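As a small aside, here is a minimal sketch (made up for illustration) of that tuple-versus-list distinction in practice:

```elixir
defmodule ContainerExample do
  import Nx.Defn

  # Tuples implement Nx.Container, so they can be passed into defn and
  # destructured inside it.
  defn sum_pair(pair) do
    {a, b} = pair
    a + b
  end
end

ContainerExample.sum_pair({Nx.tensor(1), Nx.tensor(2)})
# Passing a plain list instead would raise, since lists are not Nx containers.
```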
Conclusion
In this post, I tried to accomplish several things at once, and perhaps that led to a somewhat cluttered article, but I felt the need to address all of these points together. I do not mean to suggest that machine learning has no place in Python or that Python will not continue to be the dominant player in machine learning, but I do think some healthy competition is a good thing, and Python has some shortcomings that give other languages valid reasons to coexist in the space.
Next, I wanted to address some specifics of what Elixir has to offer the machine learning space. I think it is uniquely positioned to be quite competitive, considering the large community push to support more and more libraries, as well as the large application development community that can benefit from an in-house solution.
Lastly, I wanted to share some practical tips for those looking to move from Python to Elixir but feeling somewhat helpless in the process. I think Sean Moriarity's book, which I mentioned at the beginning of this article, is an invaluable resource and a great step in the machine learning education of Elixir developers, but it can nonetheless feel daunting to throw out existing, working solutions for newfangled and perhaps less battle-tested ones. I hope I have shown how anybody can approach this problem, and that any existing Elixir developer can be a machine learning developer going forward. The groundwork has been laid and the tools are available. Thank you for reading (especially if you made it to the end)!