Table of Contents
- Why learn Einops?
- Why read this guide instead of the official Einops tutorial
- What's in the tutorial
- Transposing tensors
- Composition of axes
- Decomposition of axis
- Einops reduce
- Repeating elements
- Einsum
- Summary of linear algebra operations using Einsum
Why learn Einops?
Einops provides a way for tensor manipulations to be self-documenting. Key advantages of Einops over traditional methods such as reshape include:
- Reliable: in addition to doing what is expected, it fails explicitly for wrong inputs
- Readable: code written with Einops is self-documenting, and hence easier to read without requiring additional comments
- Maintainable: readable code is more maintainable
There are also other claimed benefits, such as time and memory efficiency, which are not covered in this article.
For example, given a tensor of shape (2, 3, 4) called encode_weights:
# encode_weights with shape (2, 3, 4)
encode_weights = np.array([
[[0.1, 0.2, 0.3, 0.4], # pos=0, dim1=0
[0.5, 0.6, 0.7, 0.8], # pos=0, dim1=1
[0.9, 1.0, 1.1, 1.2]], # pos=0, dim1=2
[[2, 3, 4, 5], # pos=1, dim1=0
[6, 7, 8, 9], # pos=1, dim1=1
[1, 0, 1, 0]] # pos=1, dim1=2
])
And suppose we want to flatten it while preserving the first dimension, producing the following:
[[0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. 1.1 1.2]
[2. 3. 4. 5. 6. 7. 8. 9. 1. 0. 1. 0. ]]
Previously, without Einops, we could use reshape as below to achieve the intended result. Here's how to interpret the code:
- Keep the first dimension unchanged: `encode_weights.shape[0]` preserves the size of the first axis.
- Flatten all remaining dimensions: the `-1` tells NumPy/PyTorch to automatically calculate the size needed to flatten all other dimensions into a single dimension.
encode_weights.reshape(encode_weights.shape[0], -1)
However, when manipulating tensors, we are primarily concerned with the inputs and outputs, not with how the operations achieve them. From the code, it is not clear what the shapes of the input and output tensors are, so they have to be documented separately with comments. This can become a problem in a large code base where code drift happens, and outdated comments confuse instead of help.
In contrast, we could use the rearrange operation in Einops to have a self-documenting way of manipulating the tensor to achieve the same outcome. From the code, we can see that:
- `encode_weights` has 3 dimensions, and is of shape (2, 3, 4)
- The output has 2 dimensions, and is of shape (2, 3 * 4)
- The code will fail explicitly if the dimensions specified are wrong, e.g. after code drift (see the sketch after the code below)
rearrange(encode_weights, "pos dim1 dim2 -> pos (dim1 dim2)", pos=2, dim1=3, dim2=4)
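To see the explicit failure in action, here is a minimal sketch, assuming the encode_weights tensor defined above (the exact error message may vary between einops versions):

from einops import rearrange, EinopsError

try:
    # Deliberately declare pos=3 even though the first axis of encode_weights has length 2
    rearrange(encode_weights, "pos dim1 dim2 -> pos (dim1 dim2)", pos=3, dim1=3, dim2=4)
except EinopsError as e:
    print(f"Einops failed explicitly: {e}")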
Why read this guide instead of the official Einops tutorial
Einops has a great multi-part tutorial covering all the key features, and this guide complements the official Einops tutorials rather than replacing them.
While the official tutorial's use of image tensors (pictures of the letters 'e', 'i', 'n', 'o', 'p', 's') may be useful for fields such as computer vision, I found that I could learn better by being able to trace numbers in tensors visually.
Being a complement, I will largely follow the structure of Einops tutorial part 1, but using small tensors with numbers instead of images.
What's in the tutorial
I will use two tensors, embeddings and encode_weights to demonstrate the following operations: rearrange, reduce, repeat, and einsum. I will also link the operations back to concepts in linear algebra.
import numpy as np
from einops import rearrange
# New embeddings with shape (1, 2, 3)
embeddings = np.array([
[[1, 2, 3], # pos=0
[4, 5, 6]] # pos=1
])
# encode_weights with shape (2, 3, 4)
encode_weights = np.array([
[[0.1, 0.2, 0.3, 0.4], # pos=0, dim1=0
[0.5, 0.6, 0.7, 0.8], # pos=0, dim1=1
[0.9, 1.0, 1.1, 1.2]], # pos=0, dim1=2
[[2, 3, 4, 5], # pos=1, dim1=0
[6, 7, 8, 9], # pos=1, dim1=1
[1, 0, 1, 0]] # pos=1, dim1=2
])
Transposing tensors
In linear algebra, the transpose of a matrix is an operator that flips a matrix over its diagonal, i.e. it switches the row and column indices of the matrix to produce another matrix, often denoted by $A^\mathsf{T}$. Wikipedia has a good animation illustrating this operation.
Using the embeddings tensor as an example (one batch, two pos, and dim1 of three), we can transpose it along pos and dim1 by switching the two axes with rearrange:
print("original:")
print(embeddings)
print("transposed:")
print(rearrange(embeddings, "batch pos dim1 -> batch dim1 pos", batch=1, pos=2, dim1=3))
We get the following output:
original:
[[[1 2 3]
[4 5 6]]]
transposed:
[[[1 4]
[2 5]
[3 6]]]
We can also switch the axes batch and pos, so the output shape becomes (2, 1, 3), with pos now leading:
print("original:")
print(embeddings)
print("transposed:")
print(rearrange(embeddings, "batch pos dim1 -> pos batch dim1", batch=1, pos=2, dim1=3))
We get the following output:
original:
[[[1 2 3]
[4 5 6]]]
transposed:
[[[1 2 3]]
[[4 5 6]]]
We can verify that transposing embeddings_transposed returns the original embeddings tensor:
embeddings_transposed = rearrange(embeddings, "batch pos dim1 -> batch dim1 pos", batch=1, pos=2, dim1=3)
embeddings_returned = rearrange(embeddings_transposed, "batch dim1 pos -> batch pos dim1", batch=1, pos=2, dim1=3)
print(embeddings_returned)
The output:
[[[1 2 3]
[4 5 6]]]
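For comparison, here is a minimal sketch of the same transpose in plain NumPy, where the axis permutation (0, 2, 1) carries none of the axis names:

# Permute axes (batch, pos, dim1) -> (batch, dim1, pos) by position
numpy_transposed = embeddings.transpose(0, 2, 1)
einops_transposed = rearrange(embeddings, "batch pos dim1 -> batch dim1 pos")
print(np.array_equal(numpy_transposed, einops_transposed))  # True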
Composition of axes
Axis composition means combining multiple axes into one. The order of composition matters, which I will also demonstrate.
We can use round brackets ( and ) to compose two axes into a single axis. Using the encode_weights tensor as an example, the following code combines the original batch of two and pos of three into a single (batch pos) axis of six, keeping dim1 of four.
print("original:")
print(encode_weights)
print("Composing batch and pos dimensions to a new height dimension:")
print(rearrange(encode_weights, "batch pos dim1 -> (batch pos) dim1", batch=2, pos=3, dim1=4))
The output is as below:
original:
[[[0.1 0.2 0.3 0.4]
[0.5 0.6 0.7 0.8]
[0.9 1. 1.1 1.2]]
[[2. 3. 4. 5. ]
[6. 7. 8. 9. ]
[1. 0. 1. 0. ]]]
Composing batch and pos dimensions to a new height dimension:
[[0.1 0.2 0.3 0.4]
[0.5 0.6 0.7 0.8]
[0.9 1. 1.1 1.2]
[2. 3. 4. 5. ]
[6. 7. 8. 9. ]
[1. 0. 1. 0. ]]
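This composition agrees with a plain reshape; a quick sketch to verify:

# (batch pos) is the named equivalent of reshaping the first two axes into one of length 2 * 3 = 6
composed = rearrange(encode_weights, "batch pos dim1 -> (batch pos) dim1")
print(np.array_equal(composed, encode_weights.reshape(6, 4)))  # True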
If we swap the new axis (batch * pos) with (pos * batch), we will instead interleave each pos with each batch. An intuitive way to understand this operation is to break it up into two steps:
- Step 1: Transpose `batch` with `pos`
- Step 2: Build the new composite axis `(pos batch)`
print("original:")
print(encode_weights)
print("Step 1: transpose `batch` with `pos`")
encoded_weights_transposed = rearrange(encode_weights, "batch pos dim1 -> pos batch dim1", batch=2, pos=3, dim1=4)
print(encoded_weights_transposed)
print("Step 2: build the new composite axis (`pos` * `batch`)")
print(rearrange(encoded_weights_transposed, "pos batch dim1 -> (pos batch) dim1", batch=2, pos=3, dim1=4))
The output is as follows:
original:
[[[0.1 0.2 0.3 0.4]
[0.5 0.6 0.7 0.8]
[0.9 1. 1.1 1.2]]
[[2. 3. 4. 5. ]
[6. 7. 8. 9. ]
[1. 0. 1. 0. ]]]
Step 1: transpose `batch` with `pos`
[[[0.1 0.2 0.3 0.4]
[2. 3. 4. 5. ]]
[[0.5 0.6 0.7 0.8]
[6. 7. 8. 9. ]]
[[0.9 1. 1.1 1.2]
[1. 0. 1. 0. ]]]
Step 2: build the new composite axis (`pos` * `batch`)
[[0.1 0.2 0.3 0.4]
[2. 3. 4. 5. ]
[0.5 0.6 0.7 0.8]
[6. 7. 8. 9. ]
[0.9 1. 1.1 1.2]
[1. 0. 1. 0. ]]
We can also have a single composite axis across all three axes, i.e.:
print(rearrange(encode_weights, "batch pos dim1 -> (batch pos dim1)", batch=2, pos=3, dim1=4))
This returns a flat vector of length 2 * 3 * 4 = 24, as seen from the output below:
[0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. 1.1 1.2 2. 3. 4. 5. 6. 7.
8. 9. 1. 0. 1. 0. ]
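This full flattening is the named counterpart of reshape(-1); a quick sketch to verify:

# Composing all three axes matches flattening the whole tensor
flattened = rearrange(encode_weights, "batch pos dim1 -> (batch pos dim1)")
print(flattened.shape)  # (24,)
print(np.array_equal(flattened, encode_weights.reshape(-1)))  # True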
Decomposition of axis
In contrast to axis composition, which combines axes, axis decomposition breaks a single axis up into multiple axes.
The example below takes the original dim1 of length 4 and breaks it up into dim1 and dim2. Note that we have reused the name dim1 for this illustration, even though the new dim1 is only of length 2; this shows that axis names in Einops are local to each operation.
print("original:")
print(encode_weights)
print("Break-up dim into two axes")
print(rearrange(encode_weights, "batch pos (dim1 dim2) -> batch pos dim1 dim2", batch=2, pos=3, dim1=2, dim2=2))
This gives the following output:
original:
[[[0.1 0.2 0.3 0.4]
[0.5 0.6 0.7 0.8]
[0.9 1. 1.1 1.2]]
[[2. 3. 4. 5. ]
[6. 7. 8. 9. ]
[1. 0. 1. 0. ]]]
Break up dim1 into two axes
[[[[0.1 0.2]
[0.3 0.4]]
[[0.5 0.6]
[0.7 0.8]]
[[0.9 1. ]
[1.1 1.2]]]
[[[2. 3. ]
[4. 5. ]]
[[6. 7. ]
[8. 9. ]]
[[1. 0. ]
[1. 0. ]]]]
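As with composition, decomposition agrees with a plain reshape; a quick sketch to verify:

# Decomposing the last axis of length 4 into (2, 2) matches reshape(2, 3, 2, 2)
decomposed = rearrange(encode_weights, "batch pos (dim1 dim2) -> batch pos dim1 dim2", dim1=2, dim2=2)
print(np.array_equal(decomposed, encode_weights.reshape(2, 3, 2, 2)))  # True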
For completeness, it is illegal to decompose pos of length 3 into pos1=1 and pos2=2. The reason is that the product of the decomposed axis lengths must equal the original axis length, and since pos1 * pos2 = 1 * 2 = 2 while pos has length 3, we receive the error 3 != 2.
print("original:")
print(encode_weights)
print("Break-up pos into two axes")
print(rearrange(encode_weights, "batch (pos1 pos2) dim1 -> batch pos1 pos2 dim1", batch=2, pos1=1, pos2=2, dim1=4))
Output is an error:
EinopsError: Error while processing rearrange-reduction pattern "batch (pos1 pos2) dim1 -> batch pos1 pos2 dim1".
Input tensor shape: (2, 3, 4). Additional info: {'batch': 2, 'pos1': 1, 'pos2': 2, 'dim1': 4}.
Shape mismatch, 3 != 2
Einops reduce
A reduce operation is a computation that aggregates multiple values from a collection into a single result. These operations include min, max, sum, mean, prod, any, and all. In Einops, the axes that appear in the input pattern but are missing from the output pattern are reduced.
The following reduce pattern `batch1 pos1 dim1 -> ` reduces the tensor to a single scalar, and the operation `max` returns the maximum value.
from einops import reduce
print("original:")
print(encode_weights)
print("Return the max value as a scalar")
print(reduce(encode_weights, "batch1 pos1 dim1 -> ", "max", batch1=2, pos1=3, dim1=4))
Which gives us the following output, i.e. 9.0:
original:
[[[0.1 0.2 0.3 0.4]
[0.5 0.6 0.7 0.8]
[0.9 1. 1.1 1.2]]
[[2. 3. 4. 5. ]
[6. 7. 8. 9. ]
[1. 0. 1. 0. ]]]
Return the max value as a scalar
9.0
Note that instead of returning a scalar, we can retain a three-dimensional output by writing either () or 1 for each output axis; both represent an axis of length 1 and are equivalent.
from einops import reduce
print("original:")
print(encode_weights)
print("Return the max value as a scalar:")
print(reduce(encode_weights, "batch1 pos1 dim1 -> 1 1 1", "max", batch1=2, pos1=3, dim1=4))
print("Assert that `()` and `1` are equivalent:")
print(np.array_equal(reduce(encode_weights, "batch1 pos1 dim1 -> 1 1 1", "max", batch1=2, pos1=3, dim1=4), reduce(encode_weights, "batch1 pos1 dim1 -> () () ()", "max", batch1=2, pos1=3, dim1=4)))
original:
[[[0.1 0.2 0.3 0.4]
[0.5 0.6 0.7 0.8]
[0.9 1. 1.1 1.2]]
[[2. 3. 4. 5. ]
[6. 7. 8. 9. ]
[1. 0. 1. 0. ]]]
Return the max value as a (1, 1, 1) tensor:
[[[9.]]]
Assert that `()` and `1` are equivalent:
True
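The reduction also does not need to collapse every axis: only the axes missing from the output pattern are reduced. A minimal sketch that takes the max over dim1 alone, keeping the other two axes:

# batch1 and pos1 appear in the output, so only dim1 is reduced
print(reduce(encode_weights, "batch1 pos1 dim1 -> batch1 pos1", "max"))
# [[0.4 0.8 1.2]
#  [5.  9.  1. ]]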
Repeating elements
As the name suggests, repeat means repeating elements in the tensor. I will demonstrate the following variations:
- Repetition along a new axis
- Repetition along an existing axis
- Interleaving of repetition
- Representing repetition as an integer instead of variable name
Repetition along a new axis
In the following code, we introduce a new axis repeat, which transforms embeddings from shape (1, 2, 3) to (1, 2, 3, 2). Note that the position of the new axis repeat matters -- by placing repeat of size two right after dim1, we convert each element along dim1 into an array of length two, where the element is repeated twice.
print("original:")
print(embeddings)
print("Repeat with new axis:")
repeat(embeddings, "batch pos dim1 -> batch pos dim1 repeat", batch=1, pos=2, dim1=3, repeat=2)
The output is below.
original:
[[[1 2 3]
[4 5 6]]]
Repeat with new axis:
array([[[[1, 1],
[2, 2],
[3, 3]],
[[4, 4],
[5, 5],
[6, 6]]]])
Repetition along an existing axis
We can repeat along the axis dim1 by composing dim1 and repeat into a single axis (dim1 repeat).
print("Tensor from previous step:")
print(repeat(embeddings, "batch pos dim1 -> batch pos dim1 repeat", batch=1, pos=2, dim1=3, repeat=2))
print("Repeat with new axis:")
repeat(embeddings, "batch pos dim1 -> batch pos (dim1 repeat)", batch=1, pos=2, dim1=3, repeat=2)
The output is as follows:
Tensor from previous step:
[[[[1 1]
[2 2]
[3 3]]
[[4 4]
[5 5]
[6 6]]]]
Repeat along the existing axis:
array([[[1, 1, 2, 2, 3, 3],
[4, 4, 5, 5, 6, 6]]])
Interleaving of repetition
The order of composition matters: by swapping the composite axis (dim1 repeat) for (repeat dim1), we interleave the new axis repeat with dim1.
print("Tensor from previous step:")
print(repeat(embeddings, "batch pos dim1 -> batch pos (dim1 repeat)", batch=1, pos=2, dim1=3, repeat=2))
print("Repeat with new axis:")
repeat(embeddings, "batch pos dim1 -> batch pos (repeat dim1)", batch=1, pos=2, dim1=3, repeat=2)
The output is as follows:
Tensor from previous step:
[[[1 1 2 2 3 3]
[4 4 5 5 6 6]]]
Repeat with interleaving:
array([[[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6]]])
Representing repetition as an integer instead of variable name
Instead of representing the new axis as repeat=2, we can simply use the integer 2 to achieve the same effect:
print("Tensor from previous step:")
repeat(embeddings, "batch pos dim1 -> batch pos (repeat dim1)", batch=1, pos=2, dim1=3, repeat=2)
print("Repeat with integer 2:")
repeat(embeddings, "batch pos dim1 -> batch pos (2 dim1)", batch=1, pos=2, dim1=3)
The output shows that both operations give the same result:
Tensor from previous step:
array([[[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6]]])
Repeat with integer 2:
array([[[1, 2, 3, 1, 2, 3],
[4, 5, 6, 4, 5, 6]]])
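For reference, the two composition orders correspond to np.repeat and np.tile in plain NumPy; a sketch of the equivalence:

# (dim1 repeat), element-wise repetition, matches np.repeat on the last axis
print(np.array_equal(
    repeat(embeddings, "batch pos dim1 -> batch pos (dim1 repeat)", repeat=2),
    np.repeat(embeddings, 2, axis=2)))  # True

# (repeat dim1), tiling the whole axis, matches np.tile on the last axis
print(np.array_equal(
    repeat(embeddings, "batch pos dim1 -> batch pos (repeat dim1)", repeat=2),
    np.tile(embeddings, (1, 1, 2))))  # True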
Einsum
Einsum is a function that implements the Einstein summation convention for performing multi-dimensional array operations. To understand Einsum, I will begin with a brief introduction to Einstein Notation, which provides a shorthand for expressing instructions on how to combine numbers.
Let us begin with an example.
Imagine you have two vectors, where a is [1, 2, 3] and b is [4, 5, 6]. You want to multiply the first number in a with the first in b, the second with the second, and so on and so forth, and then add everything up, which can be expressed as:

$$a_1 b_1 + a_2 b_2 + a_3 b_3 = 1 \cdot 4 + 2 \cdot 5 + 3 \cdot 6 = 32$$

This is also known as a dot product. We can also express the summation using the Sigma symbol $\sum$:

$$\sum_{i=1}^{3} a_i b_i$$
This means: "For every number i from 1 to 3, take the i-th element of a and the i-th element of b, multiply them, and then add everything up."
Note that the summation almost always happens over an index that appears twice, which we can further shorten using Einstein Notation.
Two rules of Einstein Notation
Rule 1: The Repeating Index means "Sum"
If an index (a little letter) appears twice on one side of the equation, it means you should sum over all possible values of that index. This repeated index is called a "dummy index".
Let us first rewrite the vectors, with vector a as a row vector $a_i$ (the subscript marks it as "covariant") and vector b as a column vector $b^i$ (the superscript marks it as "contravariant"). Note that in a linear space, vectors retain their values under such transformations, which is not necessarily true in non-linear spaces.
So, our long sum from before:

$$\sum_{i=1}^{3} a_i b^i$$

becomes, in Einstein Notation:

$$a_i b^i$$

Because the index i appears twice (once with a and once with b), the summation is implied.
Example: Dot Product of two vectors
Let's use our vectors from before. We'll write the row vector with a subscript, $a_i = (1, 2, 3)$, and the column vector with a superscript, $b^i = (4, 5, 6)^\mathsf{T}$.

The Einstein Notation is:

$$a_i b^i$$

This tells you to sum over the index i:

$$a_1 b^1 + a_2 b^2 + a_3 b^3 = 1 \cdot 4 + 2 \cdot 5 + 3 \cdot 6 = 32$$
We can represent the operation in Python code as below:
from einops import einsum
A = np.array([
[1,2,3]
])
B = np.array([
[4],
[5],
[6]
])
einsum(A, B, "i j, j i -> ")
The einsum operation returns a scalar of value 32, which is expected:
np.int64(32)
Rule 2: The "Free" Index informs the output's shape
If an index appears only once on one side, it's called a "free index". This index must also appear on the other side of the equation. It tells you the shape or dimension of the result.
Example: Matrix-Vector Multiplication
Let's take a matrix A and a vector v.

$$A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}$$, which we can write as $A^i_{\ j}$.

$$v = \begin{pmatrix} 5 \\ 6 \end{pmatrix}$$, which we'll write with a superscript, $v^j$.

The result of their multiplication is a new vector, let's call it u. The standard equation is:

$$u = A v$$

The Einstein Notation for this is:

$$u^i = A^i_{\ j} v^j$$
- We see the index j is repeated (once up, once down), so we sum over j.
- We see the index i appears only once on the right side, as a superscript. This means it's a free index, and the result, u, will be a column vector with the superscript i.
Let's calculate the first element of u, which is $u^1$. This is where i=1:

$$u^1 = A^1_{\ 1} v^1 + A^1_{\ 2} v^2 = 1 \cdot 5 + 2 \cdot 6 = 17$$

Now for the second element, $u^2$, where i=2:

$$u^2 = A^2_{\ 1} v^1 + A^2_{\ 2} v^2 = 3 \cdot 5 + 4 \cdot 6 = 39$$

So, the resulting column vector is $u = \begin{pmatrix} 17 \\ 39 \end{pmatrix}$.
The Python code representing the above operation is below:
from einops import einsum
A = np.array([
[1,2],
[3,4]
])
B = np.array([
[5],
[6]
])
einsum(A, B, "i j, j i -> i")
Which returns the resulting two-element vector as expected:
array([17, 39])
More Examples
Matrix-Matrix Multiplication
Let's multiply two matrices, A and B, to get a new matrix C.

$$A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}, \quad B = \begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix}$$

The Einstein Notation is:

$$C^i_{\ k} = A^i_{\ j} B^j_{\ k}$$
- The index j is repeated (summed over).
- The indices i (upstairs) and k (downstairs) are free, which tells us the result C will be a matrix with a superscript i and a subscript k.
Let's calculate one element, say $C^1_{\ 1}$ (where i=1 and k=1):

$$C^1_{\ 1} = A^1_{\ 1} B^1_{\ 1} + A^1_{\ 2} B^2_{\ 1} = 1 \cdot 5 + 2 \cdot 7 = 19$$

How about $C^2_{\ 1}$ (where i=2 and k=1):

$$C^2_{\ 1} = A^2_{\ 1} B^1_{\ 1} + A^2_{\ 2} B^2_{\ 1} = 3 \cdot 5 + 4 \cdot 7 = 43$$

If you calculate all the elements, you get the final matrix $C = \begin{pmatrix} 19 & 22 \\ 43 & 50 \end{pmatrix}$.
The above operation can be represented as code below:
from einops import einsum
A = np.array([
[1,2],
[3,4]
])
B = np.array([
[5, 6],
[7, 8]
])
einsum(A, B, "i j, j k -> i k")
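We can sanity-check the result against NumPy's matrix-multiplication operator:

C = einsum(A, B, "i j, j k -> i k")
print(C)                         # [[19 22]
                                 #  [43 50]]
print(np.array_equal(C, A @ B))  # True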
Trace of a Matrix (Sum of diagonal elements)
The trace of a matrix is the sum of its diagonal elements.
For our matrix $A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}$, the trace is $1 + 4 = 5$.

The Einstein Notation is:

$$\operatorname{tr}(A) = A^i_{\ i}$$
Here, the index i is repeated (once up, once down), so it means sum over i.
By setting the superscript and the subscript to the same dummy index, we tell einsum to look only at the diagonal elements and sum them up.
We can implement trace using code:
from einops import einsum
A = np.array([
[1,2],
[3,4]
])
trace = einsum(A, "i i ->")
print(f"The matrix A is:\n{A}")
print(f"The trace of A is: {trace}")
Which returns the following output:
The matrix A is:
[[1 2]
[3 4]]
The trace of A is: 5
How the code works:
- `einsum` looks at the input matrix `A` and the string `"i i ->"`.
- It identifies the elements of `A` where the first and second indices are identical: `A[0,0]` and `A[1,1]`.
- These elements are `1` and `4`.
- Because the output is specified as a scalar (nothing after `->`), it sums these elements together: `1 + 4 = 5`.
The result is 5, which is the correct trace of the matrix A.
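As a quick check, this agrees with NumPy's built-in trace:

print(np.trace(A))  # 5, matching einsum(A, "i i ->")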
Summary of linear algebra operations using Einsum
ajcr has a nice summary of how to replicate the more common linear algebra operations using Einsum, which I reproduce below. The original blog post is available here. Note that ajcr assumes the NumPy implementation of einsum (np.einsum). If you are using the einops variant, remember to leave a space between indices, e.g. i,ij->i in NumPy should be represented as i, i j -> i in einops.
Let A and B be two 1D arrays of compatible shapes (meaning the lengths of the axes we pair together are either equal, or one of them has length 1):
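The original tables are images in ajcr's post; the table below is my reconstruction of the most common 1D entries (cross-check against the original post for the full list):

| np.einsum call | NumPy equivalent | Description |
| --- | --- | --- |
| `np.einsum('i->', A)` | `np.sum(A)` | sum of all elements |
| `np.einsum('i,i->i', A, B)` | `A * B` | element-wise multiplication |
| `np.einsum('i,i->', A, B)` | `np.inner(A, B)` | inner (dot) product |
| `np.einsum('i,j->ij', A, B)` | `np.outer(A, B)` | outer product |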
Now let A and B be two 2D arrays with compatible shapes:
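Again, a reconstruction of the most common 2D entries, with the same caveat:

| np.einsum call | NumPy equivalent | Description |
| --- | --- | --- |
| `np.einsum('ij->ji', A)` | `A.T` | transpose |
| `np.einsum('ii->i', A)` | `np.diag(A)` | main diagonal |
| `np.einsum('ii->', A)` | `np.trace(A)` | trace |
| `np.einsum('ij->', A)` | `np.sum(A)` | sum of all elements |
| `np.einsum('ij->j', A)` | `np.sum(A, axis=0)` | column sums |
| `np.einsum('ij->i', A)` | `np.sum(A, axis=1)` | row sums |
| `np.einsum('ij,ij->ij', A, B)` | `A * B` | element-wise multiplication |
| `np.einsum('ij,jk->ik', A, B)` | `A @ B` | matrix multiplication |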
Note that we can use the ellipsis '...' to conveniently represent axes we are not interested in, e.g. np.einsum('...ij,ji->...', a, b) will multiply just the last two axes of a with the 2D array b.