DEV Community

Kurt

Posted on • Originally published at getcode.substack.com on

From Basic to Fancy Indexing

This post describes the myriad ways of indexing tensors in libraries like NumPy and PyTorch. We start from the basics: indexing with integers, identical to multi-dimensional array indexing in everyday programming languages. We add slicing, a terse language to select regular subsets of a tensor without copying its underlying data. Finally, advanced or fancy indexing is a NumPy feature that allows indexing tensors with other tensors. All these can be combined to slice and dice tensors in every way imaginable.

You’ll be slicing, dicing and chopping tensors like these veggies in no time. Photo by Sanket Shah on Unsplash

To explain, I'll give examples using Tensorken, the tensor library I'm developing in Rust. All code here works as of the v0.4 tag. I'll also dive into implementation details after the indexing semantics are clear.

Don't worry if the examples don't make sense yet - this is just to whet your appetite!

>>> let t = TrI::new(&[4, 3, 2], &(1..25).collect::<Vec<_>>())
                                                     
                                             
  1   2     7    8      13   14     19   20  
  3   4     9    10     15   16     21   22  
  5   6     11   12     17   18     23   24  
                                             
                                                     

// Slicing allows taking regular subsets of a tensor
>>> t.vix3(1..3, 2.., ..)
                       
 [ 11  12]   [ 17  18] 
                       

// Fancy indexing allows indexing a tensor with another int tensor
>>> let i = TrI::new(&[2], &[2, 1])
[ 2  1]

>>> t.vix3(&i, &i, 1)
[ 18  10]

// Masking allows indexing a tensor with another bool tensor
>>> let b = TrB::new(&[4], &[false, false, true, false])
[ false  false  true  false]

>>> t.vix3(&b, &i, 1..)
               
 [ 18]   [ 16] 
               

Previously in Tensors From Scratch: tensor basics, GPU acceleration, and automatic differentiation

This post is the fourth in the Tensors from Scratch series. I try to make each as self-contained as possible, but if you have the time, I recommend reading the first post in the series. It lays the necessary groundwork for this one. I'll recap the salient parts throughout. Here's an overview:

  1. Fun and Hackable Tensors in Rust, From Scratch: Basic implementation of tensors on the CPU. Explains concepts like shape broadcasting, and describes essential tensor operations.
  2. Massively Parallel Fun with GPUs: Accelerating Tensors in Rust: An implementation of the essential tensor operations from part 1, but this time on the GPU. Read this if you're interested in GPU computation and how it's different from working with the CPU.
  3. Beyond Backpropagation - Higher Order, Forward and Reverse-mode Automatic Differentiation for Tensorken: Adding automatic differentiation to Tensorken, in a particularly flexible way that allows arbitrary combinations of forward and reverse AD up to any order. Read this if you're interested in how AD works.

Basic indexing

Let's start with the basics: indexing with a single integer index per dimension, exactly like n-dimensional array indexing in nearly all programming languages. Let's use the following 3-dimensional tensor as a running example:

>>> let t = TrI::new(&[4, 3, 2], &(1..25).collect::<Vec<_>>())
                                                     
                                             
  1   2     7    8      13   14     19   20  
  3   4     9    10     15   16     21   22  
  5   6     11   12     17   18     23   24  
                                             
  

The tensor t has shape [4, 3, 2]. We'll see the effect on the shape and result of each indexing example.

If we provide a non-negative integer index for every dimension, we get back the value at that index:

>>> t.ix3(0, 1, 1).to_scalar()
[4, 3, 2] -> []
4

In the examples, the first line always shows how the shape of the input tensor t has changed. The following lines show the result. For technical reasons, Tensorken does not use Rust's indexing notation but a family of functions, of which ix is one. Since Rust does not support variadic arguments, Tensorken's indexing functions are suffixed with the number of arguments they take: ix1, ix2, and so on.

At the risk of being obvious, it's worth stating what has happened here: an integer index selects a row in the index's corresponding dimension, and the dimension is removed from the tensor.

I use the term row loosely: in a two-dimensional tensor, it could be either a row or a column. In more dimensions, it is a tensor with one fewer dimension than the original tensor.

We have three integer indexes here, one for each dimension, so the result is a scalar with (somewhat artificially) shape [], and the value is the element of t at the "coordinate" (0, 1, 1). Each index must be smaller than the size of its dimension, so the last index can only be 0 or 1.

And that's where some programming languages stop.

The first novelty of tensor libraries is that we don't have to provide an index for every dimension:

>>> t.ix1(1)
[4, 3, 2] -> [3, 2]
         
 7    8  
 9    10 
 11   12 
         

That selects index 1 in the first dimension and leaves the other dimensions unchanged. The result is a tensor with shape [3, 2]. We see that an integer index removes the dimension it selects from: we went from three to two dimensions.

A second novelty is that we don't need to index by counting from the start. Sometimes it's handy to start counting from the end of the dimension. Compare:

>>> t.ix3(t.shape()[0] - 1, t.shape()[1] - 1, t.shape()[2] - 1).to_scalar()
[4, 3, 2] -> []
24
>>> t.ix3(tl(0), tl(0), tl(0)).to_scalar()
[4, 3, 2] -> []
24

You may recognize the arr[arr.length - 1] pattern to get the last element of an array. The second call uses tl(0) instead, where tl stands for tail: tl(0) is like 0 but starts counting at the end of the dimension. So tl(0) selects the last element in a dimension, tl(1) the second-to-last, and so on. In fact, 0 is shorthand for hd(0), which stands for head. I could have written the first example as t.ix3(hd(0), hd(1), hd(1)).

In Python (and thus NumPy and PyTorch), tl(0) and tl(1) are written as -1 and -2, which is shorter but has a displeasing asymmetry: the first element is 0, but the last element is -1. Counting from the start of the dimension is zero-based, while counting from the end is one-based. I'm biased, but I remember times when this asymmetry confused me in Python. (See also https://github.com/rust-ndarray/ndarray/issues/435.)

On to the next feature: we can leave dimensions unchanged by using .. (or : in Python). The earlier example t.ix1(1) was shorthand for t.ix3(1, .., ..). We can now index any subset of dimensions:

>>> t.ix3(.., .., 1)
[4, 3, 2] -> [4, 3]
              ┐
 2    4    6  
 8    10   12 
 14   16   18 
 20   22   24 
              ┘

The result has shape [4, 3] because the first two dimensions are left unchanged, and the last dimension is removed by the index.

Additionally, there's a shorthand for consecutive ..: Ellipsis. In Python, this is represented by ..., but Rust doesn't have that symbol. Instead of t.ix3(1, .., ..) and t.ix3(.., .., 1), we could also have written t.ix2(1, Ellipsis) and t.ix2(Ellipsis, 1).

An Ellipsis leaves all the dimensions it covers unchanged. It can cover zero or more dimensions, and there can be at most one in a given index operation; otherwise the index would be ambiguous: in t.ix3(Ellipsis, 5, Ellipsis) it's unclear which dimension 5 applies to. In that case, you have to write out the .. indexes explicitly. Besides being a bit easier to type, ellipses make your tensor program more resilient to shape changes, because the same index expression works no matter how many dimensions the ellipsis ends up covering.
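One way to think about an Ellipsis is that it gets expanded into as many .. indexes as needed before the index is applied. Here is a minimal sketch of that expansion, assuming a hypothetical, simplified Idx enum (not Tensorken's actual index type). It also treats missing trailing indexes as .., which is why t.ix1(1) behaves like t.ix3(1, .., ..).

```rust
/// A hypothetical, simplified index type for illustration only;
/// it is not Tensorken's actual index type.
#[derive(Debug, Clone, PartialEq)]
enum Idx {
    Int(usize),
    Full, // `..`
    NewAxis,
    Ellipsis,
}

/// Replace a single Ellipsis with as many `Full` indexes as needed so that
/// every dimension of a tensor with `ndim` dimensions is covered.
fn expand_ellipsis(index: &[Idx], ndim: usize) -> Vec<Idx> {
    let n_ellipsis = index.iter().filter(|i| **i == Idx::Ellipsis).count();
    assert!(n_ellipsis <= 1, "at most one Ellipsis is allowed");
    // NewAxis does not consume a dimension of the indexed tensor.
    let consumed = index
        .iter()
        .filter(|i| !matches!(i, Idx::Ellipsis | Idx::NewAxis))
        .count();
    assert!(consumed <= ndim, "too many indexes");
    let mut out = Vec::new();
    for i in index {
        if *i == Idx::Ellipsis {
            out.extend(std::iter::repeat(Idx::Full).take(ndim - consumed));
        } else {
            out.push(i.clone());
        }
    }
    // Missing trailing indexes also default to `..`.
    let covered = out.iter().filter(|i| **i != Idx::NewAxis).count();
    out.extend(std::iter::repeat(Idx::Full).take(ndim - covered));
    out
}
```

For a three-dimensional tensor, [Int(1), Ellipsis] and [Int(1)] both expand to [Int(1), Full, Full].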

So far, we can only select an entire dimension or a single row in the dimension. Via slicing, we can also express taking a "rectangular" (in n dimensions) subset:

>>> t.ix3(1..3, 1..2, ..)
[4, 3, 2] -> [2, 1, 2]
                      ┐
 [ 9  10]   [ 15  16] 
                      ┘

The .. ranges we saw earlier were slices. You can also give a lower bound (inclusive) and an upper bound (exclusive) to select just those rows. Such a range index selects a contiguous subset of the rows in its dimension.

A range index never removes dimensions, which means the only difference between a range index 1..2 and the integer index 1 is that the latter additionally removes the dimension of size 1.

You can omit the start or end from a range. If you do, they'll default to 0 and the dimension's size respectively:

>>> t.ix3(..3, 1.., ..)
[4, 3, 2] -> [3, 2, 2]
                                       
                                 
  3   4     9    10     15   16  
  5   6     11   12     17   18  
                                 
                                       

You can also use the hd and tl functions to indicate whether you're counting from the start or the end:

>>> t.ix3(..tl(0), ..tl(1), ..)
[4, 3, 2] -> [4, 2, 2]
                                                    ┐
                   ┐                         
  1   2     7   8      13   14     19   20  
  3   4     9   10     15   16     21   22  
                   ┘                         
                                                    ┘

A final index we can use is NewAxis. NewAxis inserts a new dimension with size 1:

>>> t.ix4(.., .., NewAxis, ..)
[4, 3, 2] -> [4, 3, 1, 2]
                                   
 [ 1  2]     [ 3  4]     [ 5  6]   
 [ 7  8]     [ 9  10]    [ 11  12] 
 [ 13  14]   [ 15  16]   [ 17  18] 
 [ 19  20]   [ 21  22]   [ 23  24] 
                                   

We can use as many NewAxis as we like.

Of course, we can combine all these features:

>>> t.ix4(0, 1.., NewAxis, ..tl(0))
[4, 3, 2] -> [2, 1, 1]
             
 [ 3]   [ 5] 
             

That concludes the tour of basic indexing. Basic indexing is relatively intuitive, albeit with some off-by-one madness in the mix, especially with inclusive-exclusive ranges and tl-based indexing.

The happy property of basic indexing is that it never needs to copy the tensor. Only a small section of memory containing the shape and strides needs to be updated, and the much larger buffer that contains the numbers is shared among the views you create with basic indexing.
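To make this concrete, here is a minimal sketch of a strided view, assuming a row-major buffer and a hypothetical View type holding a shape, strides, and an offset (Tensorken's actual representation may differ). Integer indexes, ranges, and NewAxis each produce only new metadata; the data buffer is never touched.

```rust
/// A hypothetical strided view: the metadata that basic indexing changes.
/// The data buffer itself lives elsewhere and is shared, never copied.
#[derive(Debug, Clone)]
struct View {
    shape: Vec<usize>,
    strides: Vec<isize>,
    offset: isize,
}

impl View {
    /// A contiguous row-major view of a tensor with the given shape.
    fn contiguous(shape: &[usize]) -> Self {
        let mut strides = vec![0isize; shape.len()];
        let mut acc = 1isize;
        for d in (0..shape.len()).rev() {
            strides[d] = acc;
            acc *= shape[d] as isize;
        }
        View { shape: shape.to_vec(), strides, offset: 0 }
    }

    /// Integer index: move the offset and drop the dimension.
    fn int(&self, dim: usize, i: usize) -> Self {
        assert!(i < self.shape[dim], "index out of bounds");
        let mut v = self.clone();
        v.offset += i as isize * v.strides[dim];
        v.shape.remove(dim);
        v.strides.remove(dim);
        v
    }

    /// Range index start..end: move the offset and shrink the dimension.
    fn range(&self, dim: usize, start: usize, end: usize) -> Self {
        assert!(start <= end && end <= self.shape[dim], "invalid range");
        let mut v = self.clone();
        v.offset += start as isize * v.strides[dim];
        v.shape[dim] = end - start;
        v
    }

    /// NewAxis: insert a dimension of size 1. Its only valid coordinate
    /// is 0, so its stride never matters; 0 is used here.
    fn new_axis(&self, dim: usize) -> Self {
        let mut v = self.clone();
        v.shape.insert(dim, 1);
        v.strides.insert(dim, 0);
        v
    }

    /// Read the element at a full coordinate from the shared buffer.
    fn get<T: Copy>(&self, data: &[T], coord: &[usize]) -> T {
        let pos: isize = self.offset
            + coord
                .iter()
                .zip(&self.strides)
                .map(|(&c, &s)| c as isize * s)
                .sum::<isize>();
        data[pos as usize]
    }
}
```

For example, View::contiguous(&[4, 3, 2]).int(0, 0).int(0, 1).int(0, 1) ends up at offset 3 of the buffer holding 1 to 24, which is the element 4, without copying anything.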

The real brain-fuckery begins with so-called fancy indexing.

Fancy indexing

Fancy (or advanced) indexing allows you to use int or bool tensors as indexes.

Indexing with int tensors

Let's start with a one-dimensional int tensor as an index:

>>> let i = TrI::new(&[2], &[1, 2])
[ 1  2]

>>> t.oix1(&i)
[4, 3, 2] -> [2, 3, 2]
                           
                       
  7    8      13   14  
  9    10     15   16  
  11   12     17   18  
                       
                           

To enable fancy indexing in addition to basic indexing, Tensorken makes you use the oix function, which stands for outer indexing. Because fancy indexing always copies the indexed tensor, and basic indexing never copies, it made sense to make the difference apparent in the API. (Python does not do this. It has more important things to worry about, I suppose.)

In this example, we can see that the elements of the indexing tensor i are interpreted as indexes in the first dimension of the indexed tensor t. It's like slicing, except we're not limited to taking contiguous sets of elements. Since i has size 2, the first dimension of the result also has size 2, and it contains rows 1 and 2 of t.

This kind of indexing is like selection and permutation - we can change the order of the elements of t:

>>> let i = TrI::new(&[4], &[3, 0, 2, 1])
[ 3  0  2  1]

>>> t.oix1(&i)
[4, 3, 2] -> [4, 3, 2]
                                                     
                                             
  19   20     1   2     13   14     7    8   
  21   22     3   4     15   16     9    10  
  23   24     5   6     17   18     11   12  
                                             
                                                     

The elements of i do not need to be unique - we can duplicate elements of t:

>>> let i = TrI::new(&[4], &[0, 0, 1, 1])
[ 0  0  1  1]

>>> t.oix1(&i)
[4, 3, 2] -> [4, 3, 2]
                                                   
                                           
  1   2     1   2     7    8      7    8   
  3   4     3   4     9    10     9    10  
  5   6     5   6     11   12     11   12  
                                           
                                                   

Or increase the size of the indexed dimension:

>>> let i = TrI::new(&[5], &[1; 5])
[ 1  1  1  1  1]

>>> t.oix1(&i)
[4, 3, 2] -> [5, 3, 2]
                                                                     
                                                           
  7    8      7    8      7    8      7    8      7    8   
  9    10     9    10     9    10     9    10     9    10  
  11   12     11   12     11   12     11   12     11   12  
                                                           
                                                                     

We're not restricted to indexing with a 1-dimensional indexing tensor. Re-arranging things a bit:

>>> let i = TrI::new(&[2, 2], &[0, 1, 1, 0])
       
 0   1 
 1   0 
       

>>> t.oix1(&i)
[4, 3, 2] -> [2, 2, 3, 2]
                           
                       
  1   2       7    8   
  3   4       9    10  
  5   6       11   12  
                       
                       
  7    8      1   2    
  9    10     3   4    
  11   12     5   6    
                       
                           

In terms of the shape, the indexing tensor's shape replaces the indexed dimension, so we can increase the number of dimensions of t.

Finally, we can combine fancy indexing with slicing to rearrange any dimension:

>>> let i = TrI::new(&[2, 2], &[0, 1, 1, 0])
       
 0   1 
 1   0 
       

>>> t.oix3(.., .., &i)
[4, 3, 2] -> [4, 3, 2, 2]
                                         
                                   
  1   2       3   4       5   6    
  2   1       4   3       6   5    
                                   
                                   
  7   8       9    10     11   12  
  8   7       10   9      12   11  
                                   
                                   
  13   14     15   16     17   18  
  14   13     16   15     18   17  
                                   
                                   
  19   20     21   22     23   24  
  20   19     22   21     24   23  
                                   
                                         

To summarize our observations so far:

  • Based on the position of the indexing tensor i in the index expression, it is matched with a dimension in the indexed tensor t in exactly the same way as any other index expression in basic indexing.
  • The non-negative integer values in the indexing tensor are interpreted as indexes in the indexed tensor's corresponding dimension. These indexes select rows in that dimension.
  • The indexing tensor i can have an arbitrary shape, and this shape replaces the indexed dimension of t.
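These three rules fit in a short gather loop. The sketch below outer-indexes one dimension of a flat, row-major buffer with an int index tensor; index_dim is a hypothetical helper, not Tensorken's API. Unlike the views created by basic indexing, it has to copy the selected rows into a new buffer.

```rust
/// Index dimension `dim` of a row-major tensor with an int index tensor.
/// `idx` holds the index tensor's elements in row-major order and
/// `idx_shape` its shape. Returns the copied data and the new shape.
fn index_dim(
    data: &[i32],
    shape: &[usize],
    dim: usize,
    idx: &[usize],
    idx_shape: &[usize],
) -> (Vec<i32>, Vec<usize>) {
    assert_eq!(idx.len(), idx_shape.iter().product::<usize>());
    // Number of blocks before `dim`, and elements per selected row after it.
    let outer: usize = shape[..dim].iter().product();
    let inner: usize = shape[dim + 1..].iter().product();
    let mut out = Vec::with_capacity(outer * idx.len() * inner);
    for o in 0..outer {
        for &i in idx {
            assert!(i < shape[dim], "index out of bounds");
            let start = (o * shape[dim] + i) * inner;
            out.extend_from_slice(&data[start..start + inner]);
        }
    }
    // The index tensor's shape replaces the indexed dimension.
    let mut new_shape = shape[..dim].to_vec();
    new_shape.extend_from_slice(idx_shape);
    new_shape.extend_from_slice(&shape[dim + 1..]);
    (out, new_shape)
}
```

Indexing dimension 0 of the running example with [1, 2] copies elements 7 to 18 and yields shape [2, 3, 2], matching t.oix1(&i) above.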

Int tensor indexes create considerably more expressive power. Where basic indexing only reduces the size of the tensor and only selects regular parts of it, now we can almost arbitrarily extend or rearrange the tensor.

Multiple fancy indexes

So far, we've only used a single int tensor as an index. What happens when we use multiple fancy indexers in the same indexing expression?

This is where things get mind-blowing. Tensorken implements two distinct ways of handling such cases, called outer indexing via oix and vectorized indexing via vix. The latter is the more powerful. Vectorized indexing can express everything outer indexing can and more. But it is also harder to understand, and the use cases where outer indexing is applicable are often easier to express with outer indexes than with vectorized indexes.

Since outer indexing is a relatively straightforward generalization of basic indexing with slices, we'll start with that and then work our way through vectorized indexing.

Outer indexing - an extension of slicing

Slicing in our running example:

>>> let t = TrI::new(&[4, 3, 2], &(1..25).collect::<Vec<_>>())
                                                     
                                             
  1   2     7    8      13   14     19   20  
  3   4     9    10     15   16     21   22  
  5   6     11   12     17   18     23   24  
                                             
                                                     

>>> t.ix3(1..3, 1..2, ..)
[4, 3, 2] -> [2, 1, 2]
                      ┐
 [ 9  10]   [ 15  16] 
                      ┘

Now, if what we said above about fancy indexing makes sense, expanding the slices to int tensors should give the same result:

>>> let i1 = TrI::new(&[2], &[1, 2])
[ 1  2] // equivalent to slice 1..3

>>> let i2 = TrI::new(&[1], &[1])
[ 1]    // equivalent to slice 1..2

>>> let i3 = TrI::new(&[2], &[0, 1])
[ 0  1] // equivalent to slice ..

And that works!

>>> t.oix3(&i1, &i2, &i3)
[4, 3, 2] -> [2, 1, 2]
                      ┐
 [ 9  10]   [ 15  16] 
                      ┘

As before, the advantage of int tensors is that we are not restricted to regular slices - we can arbitrarily duplicate or rearrange elements. The following example keeps just the first and last element in each dimension, and reverses their order:

>>> let i1 = TrI::new(&[2], &[3, 0])
[ 3  0]

>>> let i2 = TrI::new(&[2], &[2, 0])
[ 2  0]

>>> let i3 = TrI::new(&[2], &[1, 0])
[ 1  0]

>>> t.oix3(&i1, &i2, &i3)
[4, 3, 2] -> [2, 2, 2]
                         
                     
  24   23     6   5  
  20   19     2   1  
                     
                         

And we can still use multi-dimensional indexers to change the shape:

>>> let i1 = TrI::new(&[2, 2], &[3, 3, 0, 0])
       
 3   3 
 0   0 
       

>>> let i2 = TrI::new(&[2], &[2, 0])
[ 2  0]

>>> let i3 = TrI::new(&[2, 2], &[1, 0, 1, 0])
       
 1   0 
 1   0 
       

>>> t.oix3(&i1, &i2, &i3)
[4, 3, 2] -> [2, 2, 2, 2, 2]
                                                           
                                                       
                                               
   24   23     20   19       6   5     2   1   
   24   23     20   19       6   5     2   1   
                                               
                                               
   24   23     20   19       6   5     2   1   
   24   23     20   19       6   5     2   1   
                                               
                                                       
                                                           

Pretty crazy stuff.

What's crazier is that this is not what NumPy does. From the NumPy docs:

Advanced (fancy -ed) indices always are broadcast and iterated as one

What the hell does that mean?

Vectorized indexing - the NumPy way

I'll first give a quick refresher on broadcasting, which you can skip if you're familiar with it, and then move on to vectorized indexing.

Broadcasting refresher

Broadcasting refers to what tensor libraries do when they apply binary element-wise operations to two tensors with different shapes.

For example, it's clear what to do if you want to element-wise multiply a tensor with shape [4, 3] with another tensor of the same shape: you multiply each element in the left tensor with the corresponding element in the right tensor. The shape of the result is unchanged at [4, 3]. Likewise, it's intuitive what happens when you want to multiply a singleton tensor of shape [1] with any other tensor: multiply the single element on the left with every element in the tensor on the right. The resulting shape is the shape of the tensor on the right.

Broadcasting formalizes and generalizes these cases, based on two rules that are applied when two tensors are not the same shape:

  1. If the number of dimensions differs, add dimensions of size 1 at the start of the shape of the tensor with fewer dimensions.
  2. Looking at the dimension sizes in pairs, if in every pair the sizes are the same or one of them is 1, then the tensors can be broadcasted together. The resulting shape is the pairwise maximum of the input shapes.

A nice way to write this is to right-align the shapes:

lhs [3, 4, 5]
rhs [4, 5]

// rule 1
lhs [3, 4, 5]
rhs [1, 4, 5]

// rule 2
result [3, 4, 5]

Align the input tensors on the right, and add 1s to the front. Now in each column, there must be a 1, or the sizes must be equal. If so, the resulting shape is just the maximum of each pair. Otherwise, the shapes are not broadcast-able.
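The two rules translate directly into code. A minimal sketch, with broadcast_shape as a hypothetical helper that returns None when the shapes are not broadcast-able:

```rust
/// Broadcast two shapes: right-align them, pad the shorter one with 1s
/// (rule 1), then require each pair to be equal or contain a 1 (rule 2).
/// Returns None when the shapes are not broadcast-able.
fn broadcast_shape(lhs: &[usize], rhs: &[usize]) -> Option<Vec<usize>> {
    let n = lhs.len().max(rhs.len());
    // Rule 1: prepend 1s so both shapes have n dimensions.
    let pad = |s: &[usize]| -> Vec<usize> {
        let mut v = vec![1; n - s.len()];
        v.extend_from_slice(s);
        v
    };
    let (l, r) = (pad(lhs), pad(rhs));
    // Rule 2: a size-1 dimension stretches to match the other size.
    l.iter()
        .zip(&r)
        .map(|(&a, &b)| {
            if a == b || b == 1 {
                Some(a)
            } else if a == 1 {
                Some(b)
            } else {
                None
            }
        })
        .collect()
}
```

Collecting an iterator of Option values into Option<Vec<_>> makes the whole result None as soon as one pair of sizes is incompatible.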

From outer to vectorized indexing

Let's compare the outputs of vectorized with outer indexing.

>>> let t = TrI::new(&[3, 3], &(1..10).collect::<Vec<_>>())
           
 1   2   3 
 4   5   6 
 7   8   9 
           

>>> let i1 = TrI::new(&[2], &[0, 2])
[ 0  2]

>>> let i2 = TrI::new(&[2], &[0, 2])
[ 0  2]

We'll index the 3-by-3 matrix t with i1 and i2. Both index tensors select the first and last element in their respective dimension. With oix, we get the four corners of the matrix:

>>> t.oix2(&i1, &i2)
[3, 3] -> [2, 2]
       
 1   3 
 7   9 
       

With vix however:

>>> t.vix2(&i1, &i2)
[3, 3] -> [2]
[ 1  9]

What gives?

One way to think about this is in terms of coordinates. For outer indexing, we're selecting the coordinates in the matrix t as follows:

                 
 (0, 0)   (0, 2) 
 (2, 0)   (2, 2) 
                 

This is, in row-major order, the Cartesian product of the indexes in both tensors: [0, 2] × [0, 2].

For vectorized indexing, the coordinates end up being:

[ (0, 0)  (2, 2)]

That's what "broadcasted together and iterated as one" means: the indexes are zipped together one element at a time, subject to the rules of broadcasting, and then used as if they were one combined index.
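For a matrix and two 1-D index tensors, the difference fits in two small functions. This is a sketch on a flat row-major buffer; outer_index2 and vectorized_index2 are hypothetical helpers that mirror, but are not, Tensorken's oix2 and vix2. They only handle 1-D indexes, where broadcasting reduces to stretching a length-1 index.

```rust
/// Outer indexing of a row-major matrix with two 1-D index tensors:
/// every (row, col) combination, i.e. the Cartesian product.
fn outer_index2(m: &[i32], ncols: usize, rows: &[usize], cols: &[usize]) -> (Vec<i32>, Vec<usize>) {
    let mut out = Vec::with_capacity(rows.len() * cols.len());
    for &r in rows {
        for &c in cols {
            out.push(m[r * ncols + c]);
        }
    }
    (out, vec![rows.len(), cols.len()])
}

/// Vectorized indexing with two 1-D index tensors: broadcast them to a
/// common length, zip them, and use each pair as one coordinate.
fn vectorized_index2(m: &[i32], ncols: usize, rows: &[usize], cols: &[usize]) -> (Vec<i32>, Vec<usize>) {
    let n = match (rows.len(), cols.len()) {
        (a, b) if a == b => a,
        (1, b) => b,
        (a, 1) => a,
        (a, b) => panic!("cannot broadcast index shapes [{a}] and [{b}]"),
    };
    // A length-1 index is repeated, which is what broadcasting does in 1-D.
    let at = |idx: &[usize], k: usize| if idx.len() == 1 { idx[0] } else { idx[k] };
    let out: Vec<i32> = (0..n).map(|k| m[at(rows, k) * ncols + at(cols, k)]).collect();
    (out, vec![n])
}
```

With the 3-by-3 matrix above and both indexes equal to [0, 2], outer_index2 returns the four corners 1, 3, 7, 9 with shape [2, 2], while vectorized_index2 returns [1, 9] with shape [2].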

Note that this also illustrates the greater expressive power of vectorized indexing: it's not possible to get just the upper left and lower right corners of a matrix using outer indexing. You have to get all four of them. Using vectorized indexing, you can still get the four corners, but you have to work a bit harder for it:

>>> let i1 = TrI::new(&[2, 2], &[0, 0, 2, 2])
       
 0   0 
 2   2 
       

>>> let i2 = TrI::new(&[2, 2], &[0, 2, 0, 2])
       
 0   2 
 0   2 
       

>>> t.vix2(&i1, &i2)
[3, 3] -> [2, 2]
       
 1   3 
 7   9 
       

If you use vectorized indexing with two or more tensors that can't be broadcasted together, you'll get a panic in Tensorken and an exception in NumPy.

One ambiguity with vectorized indexing remains. Let's return to our original running example with 3 dimensions.

>>> let t = TrI::new(&[4, 3, 2], &(1..25).collect::<Vec<_>>())
                                                     
                                             
  1   2     7    8      13   14     19   20  
  3   4     9    10     15   16     21   22  
  5   6     11   12     17   18     23   24  
                                             
                                                     

>>> let i1 = TrI::new(&[2], &[0, 2])
[ 0  2]

>>> let i2 = TrI::new(&[2], &[0, 1])
[ 0  1]

>>> t.vix3(&i1, .., &i2)
[4, 3, 2] -> ???

If we index the first and the third dimensions, and with vectorized indexing we broadcast the tensors together, how do we construct the shape of the resulting tensor? With outer indexing, because all the indexers are independent, we can insert the indexing tensor's shape in the original shape at the position where the indexing tensor appears:

>>> t.oix3(&i1, .., &i2)
[4, 3, 2] -> [2, 3, 2]

The first and last dimensions are indexed by a tensor with size 2, so both their sizes become 2 in the result. The second dimension is untouched so its size is unchanged.

But for vectorized indexing we end up with a shape [2], because broadcast_shape([2], [2]) == [2]. We still need to leave the second dimension untouched...so do we end up with shape [2, 3] or shape [3, 2]?

Here's where Tensorken diverges from NumPy. Tensorken always inserts the broadcasted index dimensions at the front. So the result is:

>>> t.vix3(&i1, .., &i2)
[4, 3, 2] -> [2, 3]
              ┐
 1    3    5  
 14   16   18 
              ┘

NumPy, on the other hand, has some complicated rules around this - if all the indexed dimensions are consecutive, then it inserts the broadcasted dimensions there. Otherwise, it inserts them at the front like Tensorken. I didn't think this adds much value - it's easy to transpose the result if you need a different shape, and the NumPy rules are often confusing.

Quite a lot to get your head around initially - I recommend starting with only thinking about the shapes first. Try to predict the shape of the result based on the indexed tensor, the kind of indexing, and the indexing tensors. Here are a few examples with minimal explanation to get you going:

let arr = CpuI32::ones(&[5, 6, 7, 8]);
let i1 = &CpuI32::new(&[1], &[0]);
let i2 = &CpuI32::new(&[2], &[0, 1]);

let r = arr.oix4(.., i1, i2, ..);
assert_eq!(r.shape(), &[5, 1, 2, 8]);

let r = arr.oix4(.., i1, .., i2);
assert_eq!(r.shape(), &[5, 1, 7, 2]);

let r = arr.vix4(.., i1, .., i2);
assert_eq!(r.shape(), &[2, 5, 7]);

let r = arr.vix4(.., i1, 0, ..);
assert_eq!(r.shape(), &[1, 5, 8]);

let r = arr.vix4(.., i1, .., 0);
assert_eq!(r.shape(), &[1, 5, 7]);
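If you'd like to check your predictions, the shape rules can be written down as a small sketch. It assumes a hypothetical Ix enum that records only what matters for the shape (a .., a scalar, or an int tensor's shape), ignores NewAxis, Ellipsis, and masks, and follows the Tensorken convention described above of putting the broadcasted dimensions in front, not NumPy's.

```rust
/// Hypothetical index kinds, reduced to what matters for the result shape.
enum Ix<'a> {
    Full,                // `..`
    Int,                 // a scalar index such as 0
    Tensor(&'a [usize]), // the shape of an int index tensor
}

/// Broadcast several shapes together (right-aligned; size-1 dims stretch).
fn broadcast_all(shapes: &[&[usize]]) -> Vec<usize> {
    let n = shapes.iter().map(|s| s.len()).max().unwrap_or(0);
    let mut out = vec![1; n];
    for s in shapes {
        for (k, &d) in s.iter().enumerate() {
            let o = &mut out[n - s.len() + k];
            if *o == 1 {
                *o = d;
            } else {
                assert!(d == 1 || d == *o, "index shapes are not broadcast-able");
            }
        }
    }
    out
}

/// Outer indexing: each index acts on its own dimension independently.
fn oix_shape(shape: &[usize], index: &[Ix]) -> Vec<usize> {
    assert_eq!(shape.len(), index.len());
    let mut out = Vec::new();
    for (&d, ix) in shape.iter().zip(index) {
        match ix {
            Ix::Full => out.push(d),                   // dimension kept
            Ix::Int => {}                              // dimension removed
            Ix::Tensor(s) => out.extend_from_slice(s), // replaced by the index shape
        }
    }
    out
}

/// Vectorized indexing, Tensorken-style: the broadcast shape of all index
/// tensors comes first, followed by the dimensions left alone by `..`.
fn vix_shape(shape: &[usize], index: &[Ix]) -> Vec<usize> {
    assert_eq!(shape.len(), index.len());
    let tensors: Vec<&[usize]> = index
        .iter()
        .filter_map(|ix| match ix {
            Ix::Tensor(s) => Some(*s),
            _ => None,
        })
        .collect();
    let mut out = broadcast_all(&tensors);
    for (&d, ix) in shape.iter().zip(index) {
        if let Ix::Full = ix {
            out.push(d);
        }
    }
    out
}
```

Feeding in the shapes from the examples above, [5, 6, 7, 8] with index shapes [1] and [2], reproduces every expected shape, for instance [2, 5, 7] for vectorized indexing of the second and fourth dimensions.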

Indexing with masks

The last possible indexing type is with boolean tensors, commonly called masks.

Starting with one-dimensional indexers:

>>> let t = TrI::new(&[4, 3, 2], &(1..25).collect::<Vec<_>>())
                                                     
                                             
  1   2     7    8      13   14     19   20  
  3   4     9    10     15   16     21   22  
  5   6     11   12     17   18     23   24  
                                             
                                                     

>>> let i = TrB::new(&[4], &[false, false, true, false])
[ false  false  true  false]

>>> t.oix1(&i)
[4, 3, 2] -> [1, 3, 2]
             
           
  13   14  
  15   16  
  17   18  
           
             

The effect is unsurprising in this case: the rows with a true mask value are kept, and the others are discarded. The size of the indexed dimension in the result is equal to or smaller than its size in the original tensor t. In this example, we select only one row, because there is only a single true in i.

An important difference with int tensor indexers is that the size of the indexed dimensions can only stay the same - if all mask values are true - or decrease.

Another difference becomes apparent when we index with multi-dimensional masks.

>>> let i = TrB::new(&[3, 2], &[false, false, true, false, true, true])
               
 false   false 
 true    false 
 true    true  
               

>>> t.oix2(.., &i)
[4, 3, 2] -> [4, 3]
              ┐
 3    5    6  
 9    11   12 
 15   17   18 
 21   23   24 
              ┘

We indexed the last two dimensions of t with a mask of shape [3, 2], the same shape as those two dimensions. Since the mask contains three true values, those two dimensions are removed and replaced with a single dimension of size 3. We have to flatten them, because unlike int indexing, where the indexing shape is always a "rectangle" (a hyperrectangle in more dimensions), the true values of a bool tensor can form any jagged shape we like. In the example, we select the three elements in the lower left corner:

         
 true    
 true    true  
               

While this makes for some interesting ASCII art, that is not a valid tensor.
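One way to see why the result has to be flattened: a mask is equivalent to the list of row-major positions of its true values, and those positions are used together as one combined index. The sketch below assumes the mask covers the trailing dimensions of a flat row-major buffer; mask_trailing is a hypothetical helper, not Tensorken's API.

```rust
/// Index the trailing dimensions of a row-major tensor with a bool mask
/// whose shape equals those dimensions. The masked dimensions collapse into
/// a single dimension whose size is the number of `true` values.
fn mask_trailing(
    data: &[i32],
    shape: &[usize],
    mask: &[bool],
    mask_shape: &[usize],
) -> (Vec<i32>, Vec<usize>) {
    assert!(mask_shape.len() <= shape.len(), "mask has too many dimensions");
    let k = shape.len() - mask_shape.len();
    assert_eq!(&shape[k..], mask_shape, "mask must match the trailing dimensions");
    let inner: usize = mask_shape.iter().product();
    assert_eq!(mask.len(), inner);
    // Row-major positions of the true values within the masked block.
    let picks: Vec<usize> = (0..inner).filter(|&p| mask[p]).collect();
    let outer: usize = shape[..k].iter().product();
    let mut out = Vec::with_capacity(outer * picks.len());
    for o in 0..outer {
        out.extend(picks.iter().map(|&p| data[o * inner + p]));
    }
    // Leading dimensions are kept; the masked ones become one dimension.
    let mut new_shape = shape[..k].to_vec();
    new_shape.push(picks.len());
    (out, new_shape)
}
```

With the [3, 2] mask from the example, this returns shape [4, 3] and the rows 3, 5, 6 through 21, 23, 24, as in t.oix2(.., &i) above.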

It's perhaps easier to understand the point of bool indexing if you know that NumPy has loads of operations to construct masks, i.e. bool tensors from int or float tensors. Tensorken really has only one such operation at the moment: eq. But it's sufficient to illustrate the principle.

>>> let i = t.eq(&TrI::new(&[2], &[1, 2]))
                                               
                                       
  t   t     f   f     f   f     f   f  
  f   f     f   f     f   f     f   f  
  f   f     f   f     f   f     f   f  
                                       
                                               

>>> t.oix1(&i)
[4, 3, 2] -> [2]
[ 1  2]

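The first operation creates the mask i by comparing t element-wise with [1, 2], broadcast to the shape of t: it marks the elements in the first column that equal 1 and the elements in the second column that equal 2, which here is exactly the row [1, 2]. (I abbreviated true and false so it fits the screen better.) Then the indexing operation pulls out the values corresponding to the mask (which is just [1, 2] again, of course).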

One difference with NumPy is that Tensorken makes no distinction between vix and oix for masks. NumPy does the same "broadcasted together" thing for masks as for int tensors. This feature seems to be lightly used and not well understood, so I chose not to support it.

In conclusion, let's compare all the different forms of indexing in terms of what they can do to the indexed tensor's shape.

  • Basic indexing can remove a dimension by selecting a single element via a scalar index.
  • Basic indexing via slicing can make any dimension smaller, but not change the number of dimensions, except by adding dimensions of size 1 with NewAxis.
  • Outer indexing can keep the number of dimensions the same or increase them. It can decrease or increase the size of dimensions.
  • Vectorized indexing can change both the number and size of dimensions, but all new dimensions are added at the front.
  • Masks or boolean indexers can only decrease or keep the number of dimensions, never increase. Likewise, they can only decrease or keep the size of dimensions the same.

Implementation

The second half of this post describes the implementation of indexing operations in Tensorken, my from-scratch implementation of a tensor library in Rust. If you're only interested in the semantics of indexing operations, you might as well stop reading right now. If you're up for some Rust to gain a deeper understanding, read on.

Recap - where Tensorken is at

Tensorken has really grown up in the past year. With the addition of indexing, it's now a functionally fully-featured tensor library. It's unlikely to be fast enough for more than toy models, but we'll cross that bridge when we come to it.

Tensorken has:

  • Implementations for CPU and GPU of about 20 fundamental tensor operations such as exp, add, sum, and reshape. These operations are defined in a Rust trait RawTensorOps (formerly RawTensor, see the next section). If we want to support a new target for Tensorken, say CUDA, those are the 20 fundamental operations we have to implement.
  • An automatic differentiation layer built on about 15 fundamental differentiable operations defined in DiffableOps (formerly Diffable, see the next section). All the top-level operations on tensors, including the indexing operations, are built on these differentiable operations. That means that the composed operations are themselves differentiable.

That's where we left it in the last part. To add indexing operations, I first needed to allow bool and int tensors in Tensorken. Previously, I had only worked with tensors containing floating point numbers, using 1.0 and 0.0 as boolean values when needed. For example, the eq operation compared two tensors for equality by returning a float tensor with ones and zeros.

The Big Refactor: Bringing int and bool tensors to Tensorken

I figured the original definitions of RawTensor and Diffable would be easy to extend with other element types, given they had an associated type Elem:

pub trait RawTensor {
    type Elem;
    // fns ...
}

When I actually tried to add types besides float, I realized I was wrong. The main problem becomes apparent when we look at eq. The original signature of eq was:

pub trait RawTensor {
    type Elem;
    fn eq(&self, other: &Self) -> Self;
}

If we also have bool tensors, we'd like to make the result of eq a bool tensor. The ideal signature would be:

pub trait RawTensor {
    type Elem;
    fn eq(&self, other: &Self) -> TRes where TRes=Self<Elem=bool>;
}

That is not a valid Rust type signature. We can get close with:

pub trait RawTensor {
    type Elem;
    fn eq<TRes: RawTensor<Elem = bool>>(&self, other: &Self) -> TRes;
}

But this is not accurate: we have several implementations of RawTensor, one for CPU and another for GPU, and this signature of eq would allow the application of eq to two CPU RawTensors to return a GPU RawTensor or vice versa. That's not what we want.

Seemingly small inaccuracies like these compound. After weeks of trying to get everything to work and a long list of compiler errors, I gave up and changed tack.

I re-read my own post on typed tagless final interpreters, and realized what I should have done from the get-go: introduce a higher-order representation of the RawTensor and Diffable traits, with generic associated types.

For the occasion, I renamed RawTensor to RawTensorOps:

pub trait RawTensorOps {
    type Repr<E: Clone>: Clone;
    // fns omitted
}

We now have a generic associated type Repr<E>. This type is the concrete representation of the particular raw tensor we'll implement - for CpuRawTensor for example, it's a buffer in memory with some information on shape and strides. The associated type is generic on E, the tensor's element type - this can be f32, i32, bool, or any other type we care to support.

The raw tensor operations, represented by functions on the RawTensorOps trait, then also become generic on the element type E:

fn exp<E: Float>(t: &Self::Repr<E>) -> Self::Repr<E>;

fn add<E: Num>(lhs: &Self::Repr<E>, rhs: &Self::Repr<E>) -> Self::Repr<E>;

Each method can now use a different element type for each of its arguments and for its result. Thus, eq becomes:

fn eq<E: PartialEq + Elem>(lhs: &Self::Repr<E>, rhs: &Self::Repr<E>) -> Self::Repr<bool>;

That's exactly the type signature we wanted.

Once you have tensors with different element types, you want to cast between them. Hence the new addition:

fn cast<EFro: Elem, ETo: CastFrom<EFro> + Elem>(t: &Self::Repr<EFro>) -> Self::Repr<ETo>;

Elem, Num, and Float are relatively uninteresting traits that enable successively more operations on the element types. Those traits ensure we can't call exp on a bool tensor and other nonsense.

As a result of this refactor, every implementation consists of two parts: the representation (what we'd think of as the tensor type) and the implementation, typically a singleton or even a void type (meaning it has no instances; it's just a type) that implements the RawTensorOps trait. For example, let's look at the types involved in implementing RawTensorOps on the CPU.

First, the representation type is unchanged from before the refactor. It keeps the buffer with the tensor data, and some extra data that stores shape and other information:

pub struct CpuRawTensor<E> {
    buffer: Arc<Buffer<E>>,
    strider: ShapeStrider,
}

E is the element type: bool, i32, f32, and so on. Throughout the code, I've tried to consistently use the name E for the element type.

Then we have the implementation type:

pub enum CpuRawTensorImpl {}

Rust guarantees this type can't be instantiated, which is great because we don't need any instances:

impl RawTensorOps for CpuRawTensorImpl {
    type Repr<E: Clone> = CpuRawTensor<E>;
    // fns omitted
}

This pattern repeats for all other RawTensorOps implementations, and for all DiffableOps implementations.
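The pattern is easier to see in isolation. Below is a minimal sketch, not Tensorken's code: VecImpl (a Vec-backed implementation), the simplified Elem and Float traits, and the use of std's From in place of Tensorken's CastFrom are all assumptions made to keep it short and self-contained.

```rust
// Simplified stand-ins for Tensorken's element traits (assumptions, not its real definitions).
trait Elem: Clone {}
impl Elem for bool {}
impl Elem for i32 {}
impl Elem for f32 {}

trait Float: Elem {
    fn exp(self) -> Self;
}
impl Float for f32 {
    fn exp(self) -> Self {
        f32::exp(self) // resolves to the inherent f32 method, not this trait method
    }
}

trait RawTensorOps {
    // Generic associated type: one concrete representation per element type E.
    type Repr<E: Clone>: Clone;

    fn exp<E: Float>(t: &Self::Repr<E>) -> Self::Repr<E>;
    // Comparing two tensors yields a bool tensor of the *same* implementation.
    fn eq<E: PartialEq + Elem>(lhs: &Self::Repr<E>, rhs: &Self::Repr<E>) -> Self::Repr<bool>;
    fn cast<EFro: Elem, ETo: From<EFro> + Elem>(t: &Self::Repr<EFro>) -> Self::Repr<ETo>;
}

// An enum with no variants: it has no values, it only names the implementation.
enum VecImpl {}

impl RawTensorOps for VecImpl {
    type Repr<E: Clone> = Vec<E>;

    fn exp<E: Float>(t: &Vec<E>) -> Vec<E> {
        t.iter().cloned().map(Float::exp).collect()
    }

    fn eq<E: PartialEq + Elem>(lhs: &Vec<E>, rhs: &Vec<E>) -> Vec<bool> {
        lhs.iter().zip(rhs).map(|(a, b)| a == b).collect()
    }

    fn cast<EFro: Elem, ETo: From<EFro> + Elem>(t: &Vec<EFro>) -> Vec<ETo> {
        t.iter().cloned().map(|x| ETo::from(x)).collect()
    }
}

fn main() {
    let a: Vec<i32> = vec![1, 2, 3];
    let b: Vec<i32> = vec![1, 5, 3];
    let mask = VecImpl::eq::<i32>(&a, &b); // Vec<bool>
    let ones = VecImpl::cast::<bool, i32>(&mask); // Vec<i32>
    println!("{:?} {:?}", mask, ones);
    // VecImpl::exp::<bool>(&mask) would not compile: bool does not implement Float.
}
```

Because eq returns Self::Repr<bool>, comparing two VecImpl tensors can only produce a VecImpl bool tensor, which is the guarantee the earlier signatures couldn't express.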

Finally, we need to tie everything together and provide the user-facing API, where we put handy utility methods built out of the base DiffableOps operations. The Tensor struct is where that happens:

pub struct Tensor<T, E: Clone, I: DiffableOps<Repr<E> = T>>(
    pub(crate) T,
    pub(crate) PhantomData<(E, I)>,
);

Tensor ties the representation type T to the implementation type I via the constraint Repr<E> = T. E is again the element type. Here are some valid instantiations of Tensor with concrete types:

pub type Cpu<E> = Tensor<CpuRawTensor<E>, E, CpuRawTensorImpl>;
pub type Wgpu<E> = Tensor<WgpuRawTensor<'static, E>, E, WgpuRawTensorImpl>;

The actual types Tensorken uses are slightly more complicated because simple operation fusing is done via a RawTensorOps implementation in FuseImpl, but the principle remains the same.

And that's it as far as the big refactor is concerned. Conceptually, not much has changed compared to v0.3, but judging by the diff, I had to touch pretty much every single line of code. I tried every trick in the book to keep the edit manageable, but there was no incremental path through the refactoring, so I had to break the code hard and work through hundreds of compiler errors. The way the Rust compiler works makes this especially disheartening. For example, the compiler doesn't check borrowing rules until it's happy that all your trait types are correct. That leads to several rounds of reducing errors from 100 to 0, only to have another 100 errors pop up.

I did most of the editing manually or using search and replace. I also tried to get Gemini to rewrite some code by giving it an example rewrite and asking it to do the rest, one function at a time. That approach was somewhat successful, but I quickly grew tired of copy-pasting back and forth. I have GitHub Copilot, but it seems incapable of rewriting code; it only generates new code.

Implementing slicing

To keep Tensorken small, the higher-level tensor operations like matrix multiplication are implemented in terms of primitive differentiable operations defined on the DiffableOps trait. DiffableOps has basic operations like exp, add, and mul for arithmetic, but also a set of slicing and dicing operations to change the shape of tensors. To implement basic indexing, we'll use two in particular:

/// Crop the tensor according to the given limits. Limits are inclusive-exclusive.
pub fn crop(&self, limits: &[(usize, usize)]) -> Self;

/// Reshape the tensor to the given shape.
/// The number of elements must remain the same.
pub fn reshape(&self, shape: &[usize]) -> Self;

For details on how these work, see the first part in the Tensorken series. Here are a few quick examples:

>>> let t = &Tr::new(&[3, 2], &[2.0, 1.0, 4.0, 2.0, 8.0, 4.0])
       
 2   1 
 4   2 
 8   4 
       

>>> t.crop(&[(0, 2), (1, 2)])
   
 1 
 2 
   

>>> t.reshape(&[1, 6])
[ 2  1  4  2  8  4]

crop reduces a tensor to a rectangular subset. reshape changes the shape, but not the number of elements. Neither operation copies the tensor: they only change the view on the underlying data buffer.
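Here's a minimal sketch of why neither operation needs a copy, assuming a row-major buffer described by a shape, per-axis strides, and an offset. The View type is hypothetical (Tensorken's ShapeStrider may work differently), and to keep it short, reshape only handles contiguous views and returns None otherwise.

```rust
/// A hypothetical strided view over a flat buffer: shape, strides and an offset.
#[derive(Debug, Clone)]
struct View {
    shape: Vec<usize>,
    strides: Vec<usize>,
    offset: usize,
}

impl View {
    /// Row-major (C-order) view of a freshly allocated buffer.
    fn contiguous(shape: &[usize]) -> View {
        let mut strides = vec![0; shape.len()];
        let mut acc = 1;
        for d in (0..shape.len()).rev() {
            strides[d] = acc;
            acc *= shape[d];
        }
        View { shape: shape.to_vec(), strides, offset: 0 }
    }

    /// Crop to half-open ranges [start, end) per axis: only offset and shape change.
    fn crop(&self, limits: &[(usize, usize)]) -> View {
        let offset = self.offset
            + limits
                .iter()
                .zip(&self.strides)
                .map(|(&(start, _), &stride)| start * stride)
                .sum::<usize>();
        let shape = limits.iter().map(|&(start, end)| end - start).collect();
        View { shape, strides: self.strides.clone(), offset }
    }

    /// Reshape without copying. This sketch only handles contiguous views.
    fn reshape(&self, shape: &[usize]) -> Option<View> {
        let same_size = self.shape.iter().product::<usize>() == shape.iter().product::<usize>();
        let contiguous = View::contiguous(&self.shape).strides == self.strides;
        if !same_size || !contiguous {
            return None;
        }
        Some(View { offset: self.offset, ..View::contiguous(shape) })
    }

    /// Gather the viewed elements in row-major order (copies, for inspection only).
    fn to_vec<T: Copy>(&self, buffer: &[T]) -> Vec<T> {
        let n: usize = self.shape.iter().product();
        (0..n)
            .map(|flat| {
                let mut rem = flat;
                let mut pos = self.offset;
                for d in (0..self.shape.len()).rev() {
                    pos += (rem % self.shape[d]) * self.strides[d];
                    rem /= self.shape[d];
                }
                buffer[pos]
            })
            .collect()
    }
}

fn main() {
    let buffer = [2, 1, 4, 2, 8, 4];
    let t = View::contiguous(&[3, 2]);
    let cropped = t.crop(&[(0, 2), (1, 2)]);
    println!("{:?} -> {:?}", cropped, cropped.to_vec(&buffer));
    println!("{:?}", t.reshape(&[1, 6]).map(|v| v.to_vec(&buffer)));
}
```

The crop in main mirrors t.crop(&[(0, 2), (1, 2)]) above: it only moves the offset to 1 and shrinks the shape to [2, 1], while the buffer stays untouched.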

The advantage of implementing all operations in terms of the operations in DiffableOps is that the resulting indexing operations become differentiable too. That means we can use them in programs that learn via gradient descent.

We proceed by translating the indexing operations detailed above into a few data types that we'll interpret later on. To begin with, we'll define an IndexSpec, which contains a number of IndexElements. We split an index expression like t.ix[..4, 2, NewAxis] into three parts, and each part is an index element:

pub struct IndexSpec {
    axes: Vec<IndexElement>,
}

pub enum IndexElement {
    // A single element in an axis.
    Single(SingleIndex),
    // A range of elements in an axis. The second element is inclusive if the bool is true.
    Slice(SingleIndex, SingleIndex, bool),
    // Create a new axis with size 1.
    NewAxis,
    // Keep the remaining dimensions as is.
    Ellipsis,
}

The different cases should be fairly clear: they map one-to-one to the different slicing possibilities. SingleIndex is a simple enum that specifies whether we count from the start or from the end:

pub enum SingleIndex {
    Head(usize),
    Tail(usize),
}

Even if we allowed negative i32 indexes like Python does, I'd still translate them to this type: all indexing in Rust is done with usize, so a usize-based representation is easier to work with.
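As a sketch of that translation: the from_signed and resolve helpers below are hypothetical, and the convention that Tail(1) denotes the last element (like Python's -1) is an assumption, not necessarily how Tensorken counts.

```rust
/// Count from the start (Head) or from the end (Tail), as in the post's SingleIndex.
#[derive(Debug, Clone, Copy, PartialEq)]
enum SingleIndex {
    Head(usize),
    Tail(usize),
}

impl SingleIndex {
    /// Hypothetical helper: map a Python-style signed index onto SingleIndex.
    /// Assumed convention: -1 becomes Tail(1), the last element.
    fn from_signed(i: i64) -> SingleIndex {
        if i >= 0 {
            SingleIndex::Head(i as usize)
        } else {
            SingleIndex::Tail(i.unsigned_abs() as usize)
        }
    }

    /// Hypothetical helper: absolute position in an axis of length `len`,
    /// or None if the index is out of bounds.
    fn resolve(self, len: usize) -> Option<usize> {
        match self {
            SingleIndex::Head(i) if i < len => Some(i),
            SingleIndex::Tail(i) if (1..=len).contains(&i) => Some(len - i),
            _ => None,
        }
    }
}

fn main() {
    for i in [0, 3, -1, -4, 4] {
        let ix = SingleIndex::from_signed(i);
        println!("{i} -> {:?} -> {:?}", ix, ix.resolve(4));
    }
}
```

Only resolve needs to know the axis length, so the signed-to-usize translation can happen once, up front, without any negative arithmetic.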

The implementation translates an IndexSpec to a BasicIndexResolution:

struct BasicIndexResolution {
    limits: Vec<(usize, usize)>,
    shape: Vec<usize>,
}

We apply a BasicIndexResolution to a tensor like this:

let result = tensor.crop(&resolution.limits).reshape(&resolution.shape);

In principle, we could also crop and reshape while interpreting the index elements dimension by dimension, but it's cleaner to use only two operations.

The implementation itself is not very interesting. It's off-by-one hell: counting from the start, counting from the end, the inclusive-exclusive nature of ranges, and other such shenanigans caused me quite a few headaches.

In broad strokes, we iterate over the elements in the index spec, appending limits and updating the shape in BasicIndexResolution as we go.

The four possible cases of IndexElement are handled as follows:

  • Single: a single index, e.g. t.ix1(3). Limits are updated to keep only the given index, and the resulting dimension of size 1 is removed from the shape.
  • Slice: a range, e.g. t.ix1(1..20). Limits are updated to keep only the range. The shape is updated to the resulting size of the dimension.
  • NewAxis: a new axis of size 1. Limits are unchanged. The shape is updated with a new dimension of size 1. We can always add such a dimension because it doesn't change the overall size of the tensor.
  • Ellipsis: keep the remaining dimensions as is. Since an ellipsis can occur at most once, but anywhere in an index, we need to figure out how many dimensions it spans and then add their original limits and shape unchanged.
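The four cases above can be sketched in a few dozen lines. This is a simplified, hypothetical version, not Tensorken's code: indexes only count from the start, slices are half-open, there is no bounds checking, and it assumes at most one Ellipsis.

```rust
/// Hypothetical, simplified stand-in for IndexElement.
#[derive(Debug, Clone, Copy)]
enum Ix {
    Single(usize),
    Slice(usize, usize), // half-open [start, end)
    NewAxis,
    Ellipsis,
}

#[derive(Debug, PartialEq)]
struct BasicIndexResolution {
    limits: Vec<(usize, usize)>, // one (start, end) per input axis, for crop
    shape: Vec<usize>,           // final shape, for reshape
}

fn resolve(index: &[Ix], shape: &[usize]) -> BasicIndexResolution {
    // Single and Slice each consume one input axis; an Ellipsis spans the rest.
    let consumed = index
        .iter()
        .filter(|ix| matches!(ix, Ix::Single(_) | Ix::Slice(..)))
        .count();
    let ellipsis_len = shape.len() - consumed;

    let mut limits = Vec::new();
    let mut new_shape = Vec::new();
    let mut axis = 0;
    for &ix in index {
        match ix {
            Ix::Single(i) => {
                limits.push((i, i + 1)); // keep one element and add no output dimension
                axis += 1;
            }
            Ix::Slice(start, end) => {
                limits.push((start, end));
                new_shape.push(end - start);
                axis += 1;
            }
            Ix::NewAxis => new_shape.push(1), // size-1 axis: limits untouched
            Ix::Ellipsis => {
                for _ in 0..ellipsis_len {
                    limits.push((0, shape[axis]));
                    new_shape.push(shape[axis]);
                    axis += 1;
                }
            }
        }
    }
    // Axes not mentioned by the index are kept as is.
    while axis < shape.len() {
        limits.push((0, shape[axis]));
        new_shape.push(shape[axis]);
        axis += 1;
    }
    BasicIndexResolution { limits, shape: new_shape }
}

fn main() {
    // Like indexing a [4, 3, 2] tensor with (1..3, 2.., ..).
    let r = resolve(&[Ix::Slice(1, 3), Ix::Slice(2, 3), Ix::Slice(0, 2)], &[4, 3, 2]);
    println!("{:?}", r);
}
```

The result is then applied exactly as described above, with one crop using limits followed by one reshape using shape.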

Here's the full implementation, if you're interested in more detail. It also handles fancy indexing, which we'll discuss next.

Adding fancy indexing

We now know how to resolve basic indexes, but basic and fancy indexes can be combined in the same index expression. Before we start on the implementation of fancy indexes, we have to figure out how to interleave the two.

As it turns out, resolving a mix of basic and fancy indexes can be split into a basic indexing phase and a fancy indexing phase. We first extract and apply the basic indexing expressions, leaving any dimensions with fancy indexes intact, as if the user had specified .. for them. Thanks to the magic of basic indexing, this never needs a copy. Then we apply the fancy indexing expressions to the result.

>>> let t = TrI::new(&[4, 3, 2], &(1..25).collect::<Vec<_>>())
                                                     
                                             
  1   2     7    8      13   14     19   20  
  3   4     9    10     15   16     21   22  
  5   6     11   12     17   18     23   24  
                                             
                                                     

>>> let i1 = TrI::new(&[2], &[0, 2])
[ 0  2]

>>> let i2 = TrI::new(&[2], &[0, 1])
[ 0  1]

>>> t.vix3(&i1, ..2, &i2)
[4, 3, 2] -> [2, 2]
         
 1    3  
 14   16 
         

>>> t.vix3(.., ..2, ..).vix3(&i1, .., &i2)
[4, 3, 2] -> [2, 2]
         
 1    3  
 14   16 
         

This example splits vix3(&i1, ..2, &i2) into vix3(.., ..2, ..), which has only basic indexing expressions, and vix3(&i1, .., &i2) which has only fancy indexing expressions. Happily, we arrived at the same result, so we can now focus on implementing fancy indexing without worrying about how it interacts with basic indexing.

In the implementation, we extend the IndexElement enum with an additional case:

pub enum IndexElement<I: DiffableOps> {
    // as before
    // Fancy index - mask or int tensor.
    Fancy(Fancy<I>),
}

pub enum Fancy<I: DiffableOps> {
    Full,
    IntTensor(Tensor<I::Repr<i32>, i32, I>),
    BoolTensor(Tensor<I::Repr<bool>, bool, I>),
}

The Full case indicates that the corresponding dimension was handled in the basic indexing phase, which happens first. The fancy indexing phase keeps the dimension unchanged.

Implementing outer indexing

Let's start with outer indexing. It's not obvious how to implement it with just the operations in DiffableOps, yet there is a way. It relies on the observation that if we convert the integer indexes to one-hot vectors (vectors that have a 1 in the position they index and 0 everywhere else) and then multiply the original tensor with these one-hot vectors, we keep exclusively the elements we want.

Let's look at an example to clarify. We start with the following one-dimensional tensor, and index its only dimension with the tensor i:

>>> let t = TrI::new(&[4], &(1..5).collect::<Vec<_>>())
[ 1  2  3  4]

>>> let i = TrI::new(&[2], &[2, 0])
[ 2  0]

The expected result is [ 3 1].

First, we convert i = [ 2 0] to its corresponding one-hot representation manually (we'll see how to do this automatically later on):


>>> let i_one_hot = // coming soon
       
 0   1 
 0   0 
 1   0 
 0   0 
       

We have two indexes, 2 and 0, so there are two column vectors. The first index is 2, and the one-hot vector for 2 is [0 0 1 0], so that's the first column. The one-hot vector for 0 is [1 0 0 0], so that's the second column. We replaced the numbers in each column with one-hot vectors representing that number.

Now, after reshaping t to a column vector:

>>> let t = t.reshape(&[4, 1])
   
 1 
 2 
 3 
 4 
   

We can see that if we multiply i_one_hot with the reshaped t, thanks to broadcasting, we multiply each column in i_one_hot with the column vector t. This results in a tensor that contains only one non-zero entry per column, and this entry is the entry in the original t we want to keep:

>>> let mul_result = t.mul(&i_one_hot)
       
 0   1 
 0   0 
 3   0 
 0   0 
       

We can smell victory now! All that's left to do is get rid of the zeros. Zero is the neutral element for addition, so we can simply sum down each column (axis 0):

>>> let sum_result = mul_result.sum(&[0]).squeeze(&Axes::Axis(0))
[ 3  1]

And that's exactly what we wanted.

Now, how do we turn a vector of non-negative indexes into a one-hot representation? It's simple, once you know the trick! Create a range vector as long as the size of the dimension:

>>> let i_range = TrI::new(&[4, 1], (0..4).collect::<Vec<_>>().as_slice())
   
 0 
 1 
 2 
 3 
   

And compare it to the index vector using eq, casting true to 1 and false to 0:

>>> let i_one_hot = i.eq(&i_range).cast::<i32>()
       
 0   1 
 0   0 
 1   0 
 0   0 
       

Voila! We can repeat this procedure dimension by dimension for every outer index.
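The whole one-hot trick for a one-dimensional tensor fits in a short sketch. It uses plain Vecs and explicit loops instead of Tensorken tensors, so one_hot and outer_index_1d are hypothetical helpers, but the loops compute the same eq, cast, multiply, and sum steps shown above.

```rust
/// one_hot[r][k] = 1 if index[k] == r, mirroring `i.eq(&i_range).cast::<i32>()`
/// where i_range is a [len, 1] column and i broadcasts along the rows.
fn one_hot(index: &[usize], len: usize) -> Vec<Vec<i32>> {
    (0..len)
        .map(|r| index.iter().map(|&k| (k == r) as i32).collect())
        .collect()
}

/// Outer-index a 1-D tensor: multiply the [len, 1] column t by the
/// [len, n] one-hot matrix (broadcasting), then sum over axis 0.
fn outer_index_1d(t: &[i32], index: &[usize]) -> Vec<i32> {
    let oh = one_hot(index, t.len());
    let n = index.len();
    let mut sum = vec![0; n];
    for r in 0..t.len() {
        for k in 0..n {
            sum[k] += t[r] * oh[r][k]; // t[r] is broadcast across the columns
        }
    }
    sum
}

fn main() {
    println!("{:?}", one_hot(&[2, 0], 4));
    println!("{:?}", outer_index_1d(&[1, 2, 3, 4], &[2, 0]));
}
```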

To understand what happens in more dimensions, it's easier to think in terms of the shapes. For example, let's say we have a tensor t with shape [4, 6] and we index its first dimension with a tensor i of shape [2]. Here's what happens with the shapes:

t [4, 6]
i [2]
t[i, ..]       [2, 6] // expected result shape

// First create the one-hot representation of the index.
range          [4, 1]
// Its shape is the size of the indexed dimension, by the size of i.
one_hot        [4, 2]

// Reshape t - add a size 1 dimension...
t           [4, 1, 6]
// Reshape one_hot - add a size one dimension...
one_hot     [4, 2, 1]
// ...so that the multiplication lines up nicely for broadcasting.
t * one_hot [4, 2, 6]

// Sum and remove the dimension we don't need.
sum         [1, 2, 6]
squeeze        [2, 6]

Pretty neat! Say we wanted to index the second dimension of t with another tensor j. We'd continue from the intermediate result [2, 6] as follows:

t[i, ..]       [2, 6]
j [3]
t[i, j]        [2, 3]  // expected result shape

// First create the one-hot representation of the index.
range          [6, 1]
// Its shape is the size of the indexed dimension of t, by the size of j.
one_hot        [6, 3]

// Reshape t - add a size 1 dimension after the indexed dimension
t           [2, 6, 1]
// Reshaping one_hot isn't necessary for the last dimension:
// broadcasting adds dimensions at the front automatically.
one_hot        [6, 3]
// This multiplication lines up nicely.
t * one_hot [2, 6, 3]

// Sum and remove the dimension we don't need.
sum         [2, 1, 3]
squeeze        [2, 3]

You can equivalently think of indexing with an int tensor as matrix multiplication with a one-hot representation of the indexing tensor.
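To make the matrix multiplication view concrete, here's a sketch with plain nested Vecs (the helper names are hypothetical): multiplying the transpose of the one-hot matrix, one row per index, by t selects the indexed rows of t.

```rust
/// Transposed one-hot matrix: row k has a single 1, in column index[k].
fn one_hot_rows(index: &[usize], len: usize) -> Vec<Vec<i32>> {
    index
        .iter()
        .map(|&k| (0..len).map(|r| (r == k) as i32).collect())
        .collect()
}

/// Naive matrix multiplication of row-major nested Vecs: [n, m] x [m, p] -> [n, p].
fn matmul(a: &[Vec<i32>], b: &[Vec<i32>]) -> Vec<Vec<i32>> {
    let p = b[0].len();
    a.iter()
        .map(|row| {
            (0..p)
                .map(|c| row.iter().zip(b).map(|(&x, brow)| x * brow[c]).sum::<i32>())
                .collect::<Vec<i32>>()
        })
        .collect()
}

fn main() {
    // t has shape [4, 6] with values 1..=24, as in the examples.
    let t: Vec<Vec<i32>> = (0..4)
        .map(|r| (1..=6).map(|c| r * 6 + c).collect::<Vec<i32>>())
        .collect();
    // one_hot_rows(&[2, 0], 4) has shape [2, 4]; the product has shape [2, 6].
    for row in matmul(&one_hot_rows(&[2, 0], 4), &t) {
        println!("{:?}", row);
    }
}
```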

In summary, to outer index with an int tensor:

  • convert the int tensor to a one-hot representation
  • reshape the tensor to make room for the shape of the indexing tensor by adding dimensions of size 1
  • multiply with the one hot representation
  • reduce the original dimension by summing it.

The Rust implementation is here.
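A standalone sketch of those four steps on a two-dimensional tensor, again with plain nested Vecs and hypothetical helper names rather than Tensorken's API. Each function computes the broadcasted product with the one-hot matrix and sums away the indexed axis, following the shape walkthrough above.

```rust
/// Outer-index axis 0 of a [rows, cols] matrix with `i`: one-hot [rows, n],
/// broadcast product [rows, n, cols], then sum away axis 0 -> [n, cols].
fn oix_axis0(t: &[Vec<i32>], i: &[usize]) -> Vec<Vec<i32>> {
    let cols = t[0].len();
    let mut out = vec![vec![0; cols]; i.len()];
    for (r, row) in t.iter().enumerate() {
        for (k, &ik) in i.iter().enumerate() {
            let hot = (ik == r) as i32; // one_hot[r][k]
            for c in 0..cols {
                out[k][c] += row[c] * hot;
            }
        }
    }
    out
}

/// Same trick on axis 1: one-hot [cols, m], product [rows, cols, m],
/// then sum away axis 1 -> [rows, m].
fn oix_axis1(t: &[Vec<i32>], j: &[usize]) -> Vec<Vec<i32>> {
    t.iter()
        .map(|row| {
            j.iter()
                .map(|&jk| row.iter().enumerate().map(|(c, &x)| x * (c == jk) as i32).sum::<i32>())
                .collect()
        })
        .collect()
}

fn main() {
    let t: Vec<Vec<i32>> = (0..4)
        .map(|r| (1..=6).map(|c| r * 6 + c).collect::<Vec<i32>>())
        .collect();
    let ti = oix_axis0(&t, &[2, 0]); // shape [2, 6]
    let tij = oix_axis1(&ti, &[1, 0, 5]); // shape [2, 3]
    println!("{:?}\n{:?}", ti, tij);
}
```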

Implementing vectorized indexing

Vectorized indexing works via the same principle as outer indexing, but in vectorized indexing the indexing tensors are broadcasted together and the new dimensions are added at the front. We can still play the one hot, multiply, and sum game, but we adjust the shapes differently.

Let's go through an example again. We'll use the same tensor t but with indexing tensors i and j this time.

>>> let t = TrI::new(&[4, 6], &(1..25).collect::<Vec<_>>())
                             
 1    2    3    4    5    6  
 7    8    9    10   11   12 
 13   14   15   16   17   18 
 19   20   21   22   23   24 
                             

>>> let i = TrI::new(&[2], &[2, 0])
[ 2  0]

>>> let j = TrI::new(&[2], &[1, 0])
[ 1  0]

Using the same i and j as in the previous section won't work: their shapes, [2] and [3], can't be broadcast together.
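The compatibility rule is standard NumPy-style broadcasting. A minimal sketch of the check, with a hypothetical broadcast_shapes helper: shapes are aligned from the right, and each pair of sizes must either be equal or include a 1.

```rust
/// Broadcast two shapes NumPy-style, or return None if they are incompatible.
fn broadcast_shapes(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let n = a.len().max(b.len());
    let mut out = vec![0; n];
    for k in 0..n {
        // Missing leading dimensions count as size 1.
        let da = if k < n - a.len() { 1 } else { a[k - (n - a.len())] };
        let db = if k < n - b.len() { 1 } else { b[k - (n - b.len())] };
        out[k] = match (da, db) {
            _ if da == db => da,
            (1, d) | (d, 1) => d,
            _ => return None,
        };
    }
    Some(out)
}

fn main() {
    println!("{:?}", broadcast_shapes(&[2], &[3])); // the previous section's i and j
    println!("{:?}", broadcast_shapes(&[2], &[2])); // this section's i and j
}
```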

We again proceed iteratively, starting with i. Let's first create the one hot representation:

>>> let i_range = TrI::new(&[4], (0..4).collect::<Vec<_>>().as_slice())
[ 0  1  2  3]

>>> let i = i.reshape(&[2, 1])
   
 2 
 0 
   

>>> let i_one_hot = i.eq(&i_range).cast::<i32>()
               
 0   0   1   0 
 1   0   0   0 
               

The representation is transposed compared to the oix implementation: we now have two row vectors stacked on top of each other. That's because the new dimensions added by indexing now always go to the front of the resulting tensor, and, as we'll see in the next step, this transposed one-hot representation works out better:

>>> let i_one_hot = i_one_hot.reshape(&[2, 4, 1])
               
           
  0     1  
  0     0  
  1     0  
  0     0  
           
               

>>> let mul_result = t.mul(&i_one_hot)
                                                             
                                                         
  0    0    0    0    0    0      1   2   3   4   5   6  
  0    0    0    0    0    0      0   0   0   0   0   0  
  13   14   15   16   17   18     0   0   0   0   0   0  
  0    0    0    0    0    0      0   0   0   0   0   0  
                                                         
      

We've reshaped the one-hot representation to [2, 4, 1] and multiplied it with t of shape [4, 6]. Thanks to broadcasting, that multiplication results in a [2, 4, 6] tensor. The dimension of size 4 is the dimension we're indexing, so that's the one we need to get rid of:

>>> let t = mul_result.sum(&[1]).squeeze(&Axes::Axis(1))
                             
 13   14   15   16   17   18 
 1    2    3    4    5    6  
                             

And that's the first index done. We're left with an intermediate result t of shape [2, 6]. So far, the result is the same as if we'd used oix - the difference becomes visible when we handle the second index. We make the one hot representation of j, again transposed when compared to outer indexing:

>>> let j_range = TrI::new(&[6], (0..6).collect::<Vec<_>>().as_slice())
[ 0  1  2  3  4  5]

>>> let j = j.reshape(&[2, 1])
   
 1 
 0 
   

>>> let j_one_hot = j.eq(&j_range).cast::<i32>()
                       
 0   1   0   0   0   0 
 1   0   0   0   0   0 
                       

Now we have a one-hot shape of [2, 6] and a tensor of shape [2, 6]. We can multiply them without any further reshaping. This step is where we rely on the requirement that the indexing tensors must be broadcastable: if they weren't, the one-hot shape (determined by the second indexing tensor) and the intermediate result shape (determined by the first indexing tensor) wouldn't line up.

>>> let mul_result = t.mul(&j_one_hot)
                         
 0   14   0   0   0   0 
 1   0    0   0   0   0 
                         

Finally, we sum over axis 1, the dimension we're indexing, to get the result:

>>> let sum_result = mul_result.sum(&[1]).squeeze(&Axes::Axis(1))
[ 14  1]

Putting everything together in terms of shapes:

t [4, 6]
i [2]
t[i, ..]       [2, 6]  // expected result shape

// First create the one-hot representation of the index.
range             [4]
// Its shape is the size of i by the size of the indexed dimension of t.
one_hot        [2, 4]

// Reshape one_hot...
one_hot     [2, 4, 1]

// multiplication lines up nicely.
t * one_hot [2, 4, 6]

// Sum and remove the dimension we don't need.
sum         [2, 1, 6]
squeeze        [2, 6]

// Continue with the second index
j                 [2]
t[i, j]           [2]  // expected result shape

// First create the one-hot representation of the index.
range             [6]
// Its shape is the size of j by the size of the indexed dimension of t.
one_hot        [2, 6]

// multiplication lines up nicely.
t * one_hot    [2, 6]

// Sum and remove the dimension we don't need.
sum            [2, 1]
squeeze           [2]

The Rust implementation is here.
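For comparison with outer indexing, here's a standalone sketch of vectorized indexing of a two-dimensional tensor with two one-dimensional index tensors, using plain Vecs and a hypothetical vix2 helper. It follows the shape walkthrough above: the transposed one-hot of i picks one row per index, and then the transposed one-hot of j, which already lines up with that [2, 6] intermediate, picks one element per row.

```rust
/// Hypothetical helper: vectorized index t[i, j] of a [rows, cols] matrix,
/// pairing i[k] with j[k]. This sketch assumes equal-length 1-D index tensors.
fn vix2(t: &[Vec<i32>], i: &[usize], j: &[usize]) -> Vec<i32> {
    assert_eq!(i.len(), j.len(), "index tensors must broadcast; equal lengths assumed here");
    let cols = t[0].len();
    // Step 1: one-hot [n, rows, 1] * t [rows, cols] -> [n, rows, cols], sum over rows -> [n, cols].
    let picked_rows: Vec<Vec<i32>> = i
        .iter()
        .map(|&ik| {
            (0..cols)
                .map(|c| {
                    t.iter()
                        .enumerate()
                        .map(|(r, row)| row[c] * (r == ik) as i32)
                        .sum::<i32>()
                })
                .collect::<Vec<i32>>()
        })
        .collect();
    // Step 2: one-hot [n, cols] * picked_rows [n, cols] -> [n, cols], sum over cols -> [n].
    picked_rows
        .iter()
        .zip(j)
        .map(|(row, &jk)| {
            row.iter()
                .enumerate()
                .map(|(c, &x)| x * (c == jk) as i32)
                .sum::<i32>()
        })
        .collect()
}

fn main() {
    let t: Vec<Vec<i32>> = (0..4)
        .map(|r| (1..=6).map(|c| r * 6 + c).collect::<Vec<i32>>())
        .collect();
    println!("{:?}", vix2(&t, &[2, 0], &[1, 0]));
}
```

Unlike the outer indexing sketch, the second step never adds a new dimension: the k-th picked row is only ever combined with the k-th entry of j.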

Implementing masking

The final piece is masking, or indexing with a boolean tensor. There are no new tricks here - the best we can do (as far as I can tell) is manually convert the bool tensor to a one-dimensional vector, convert that to the equivalent int tensor, and then index using the int tensor.

For example, a bool tensor:

>>> let b = TrB::new(&[2, 3], &[false, false, true, false, true, false])
                       
 false   false   true  
 false   true    false 
                       

Is turned into:

>>> let i_b = TrI::new(&[2], &[2, 4])
[ 2  4]

Since the mask b has two dimensions, it spans two dimensions of the indexed tensor when used as an index, and those two dimensions reduce to one. To achieve this, we reshape the indexed dimensions into a single one, with a size equal to the product of the sizes of the original dimensions. Then we use outer indexing with the equivalent int index.

The implementation is here.
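A standalone sketch of that conversion for the 2 x 3 mask above, assuming row-major order. mask_to_indices and mask_index are hypothetical helpers on plain slices: the first flattens the mask and collects the positions of true, and the second applies the resulting int index to the equally flattened tensor with the one-hot trick from the outer indexing section.

```rust
/// Positions of `true` in a row-major flattened mask.
fn mask_to_indices(mask: &[bool]) -> Vec<usize> {
    mask.iter()
        .enumerate()
        .filter(|&(_, &b)| b)
        .map(|(k, _)| k)
        .collect()
}

/// Index a flattened tensor with a flattened mask of the same length:
/// convert the mask to an int index, then outer-index via one-hot multiply and sum.
fn mask_index(t: &[i32], mask: &[bool]) -> Vec<i32> {
    assert_eq!(t.len(), mask.len(), "mask must cover the flattened dimensions exactly");
    mask_to_indices(mask)
        .iter()
        .map(|&k| t.iter().enumerate().map(|(r, &x)| x * (r == k) as i32).sum::<i32>())
        .collect()
}

fn main() {
    // b from the example, flattened from shape [2, 3] to [6].
    let b = [false, false, true, false, true, false];
    println!("{:?}", mask_to_indices(&b));
    println!("{:?}", mask_index(&[10, 20, 30, 40, 50, 60], &b));
}
```

Note that the length of the int index depends on how many entries of the mask are true, so the result's shape is only known once the mask's values are.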

Conclusion

I hope that demystified tensor indexing. There's a lot more to it than meets the eye. Check out the references for useful additional material if you want to learn more.

Thank you for reading Get Code. This post is public so feel free to share it.


References
