DEV Community

Super Kai (Kazuya Ito)


take(), take_along_dim() and gather() in PyTorch

take() can take one or more elements from a 0D or more D tensor, treated as if it were flattened to 1D, using a 0D or more D tensor of indices, as shown below:

*Memos:

  • take() can be called both from torch and a tensor.
  • The 2nd argument (Required) with torch is the 0D or more D tensor of indices (int64).
  • The 1st argument (Required) with a tensor is the 0D or more D tensor of indices (int64).
  • The returned tensor has the same size as the tensor of indices.
import torch

my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])

torch.take(my_tensor, torch.tensor(3))
my_tensor.take(torch.tensor(3))
torch.take(my_tensor, torch.tensor(-7))
# tensor(6)

torch.take(my_tensor, torch.tensor([3, 0, 7, 4]))
torch.take(my_tensor, torch.tensor([-7, -10, -3, -6]))
# tensor([6, 9, 3, 2])

torch.take(my_tensor, torch.tensor([[3, 0], [7, 4]]))
torch.take(my_tensor, torch.tensor([[-7, -10], [-3, -6]]))
# tensor([[6, 9], [3, 2]])

torch.take(my_tensor, torch.tensor([[[3, 0], [7, 4]],
                                    [[8, 2], [3, 5]]]))
torch.take(my_tensor, torch.tensor([[[-7, -10], [-3, -6]],
                                    [[-2, -8], [-7, -5]]]))
# tensor([[[6, 9], [3, 2]], [[4, 0], [6, 7]]])
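
As a cross-check on the take() examples above, the following minimal sketch compares take() with plain indexing of the flattened input. The names `idx`, `flat` and `taken` are only illustrative and are not part of the original post:

```python
import torch

my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])
idx = torch.tensor([[3, 0], [7, 4]])

# take() indexes the input as if it were flattened to 1D.
taken = torch.take(my_tensor, idx)
flat = my_tensor.flatten()  # tensor([9, 5, 0, 6, 2, 7, 1, 3, 4, 8])

print(taken)                          # tensor([[6, 9], [3, 2]])
print(torch.equal(taken, flat[idx]))  # True
print(taken.shape == idx.shape)       # True
```

The result keeps the size of the tensor of indices, not the size of the input.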

take_along_dim() can take one or more elements from a 0D or more D tensor along a dimension, using a 0D or more D tensor of indices, as shown below:

*Memos:

  • take_along_dim() can be called both from torch and a tensor.
  • The 2nd argument (Required) with torch is the 0D or more D tensor of indices (int64).
  • The 1st argument (Required) with a tensor is the 0D or more D tensor of indices (int64).
  • The 3rd argument (Optional) with torch is a dimension.
  • The 2nd argument (Optional) with a tensor is a dimension.
  • If a dimension is not set, both tensors are flattened and the returned tensor is 1D.
  • If a dimension is set, both tensors must have the same number of dimensions, they are broadcast against each other in the other dimensions, and the returned tensor has that number of dimensions.
import torch

my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])

torch.take_along_dim(my_tensor, torch.tensor(3))
my_tensor.take_along_dim(torch.tensor(3))
# tensor([6])

torch.take_along_dim(my_tensor, torch.tensor([3, 0, 7, 4]))
torch.take_along_dim(my_tensor, torch.tensor([[3, 0], [7, 4]]))
# tensor([6, 9, 3, 2])

torch.take_along_dim(my_tensor, torch.tensor([[[3, 0], [7, 4]],
                                              [[8, 2], [3, 5]]]))
# tensor([6, 9, 3, 2, 4, 0, 6, 7])

torch.take_along_dim(my_tensor,
                     torch.tensor([[0], [1], [0], [1]]), 0)
torch.take_along_dim(my_tensor,
                     torch.tensor([[0], [1], [0], [1]]), -2)
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8]])

torch.take_along_dim(my_tensor, torch.tensor([[0, 0, 0, 0, 0],
                                              [1, 1, 1, 1, 1], 
                                              [0, 1, 0, 1, 0], 
                                              [1, 0, 1, 0, 1]]), 0)
torch.take_along_dim(my_tensor, torch.tensor([[0, 0, 0, 0, 0],
                                              [1, 1, 1, 1, 1], 
                                              [0, 1, 0, 1, 0], 
                                              [1, 0, 1, 0, 1]]), -2)
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 1, 0, 4, 2],
#         [7, 5, 3, 6, 8]])

torch.take_along_dim(my_tensor, torch.tensor([[3, 0, 4]]), 1)
torch.take_along_dim(my_tensor, torch.tensor([[3, 0, 4]]), -1)
# tensor([[6, 9, 2], [4, 7, 8]])

torch.take_along_dim(my_tensor,
                     torch.tensor([[3, 0, 4], [4, 1, 2]]), 1)
torch.take_along_dim(my_tensor,
                     torch.tensor([[3, 0, 4], [4, 1, 2]]), -1)
# tensor([[6, 9, 2], [8, 1, 3]])
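
A common use of take_along_dim() is to apply indices produced by functions such as argsort() or argmax(). The sketch below is an illustration only; the names `order`, `sorted_rows` and `top` are hypothetical and not from the original post:

```python
import torch

my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])

# Sort each row by applying the argsort() indices along dim 1.
order = torch.argsort(my_tensor, dim=1)
sorted_rows = torch.take_along_dim(my_tensor, order, dim=1)
print(sorted_rows)
# tensor([[0, 2, 5, 6, 9],
#         [1, 3, 4, 7, 8]])

# Pick the largest element of each row with argmax(keepdim=True),
# so the tensor of indices stays 2D like the input.
top = torch.argmax(my_tensor, dim=1, keepdim=True)
print(torch.take_along_dim(my_tensor, top, dim=1))
# tensor([[9],
#         [8]])
```

keepdim=True matters here because, when a dimension is set, both tensors must have the same number of dimensions.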

gather() can take one or more elements from a 0D or more D tensor along a dimension, using a 0D or more D tensor of indices, as shown below:

*Memos:

  • gather() can be called both from torch and a tensor.
  • The 2nd argument (Required) with torch is a dimension.
  • The 1st argument (Required) with a tensor is a dimension.
  • The 3rd argument (Required) with torch is the 0D or more D tensor of indices (int64).
  • The 2nd argument (Required) with a tensor is the 0D or more D tensor of indices (int64).
  • Basically, both tensors must have the same number of dimensions and the returned tensor has the same size as the tensor of indices. A 0D tensor can also be combined with a 1D tensor in either order, in which case the returned tensor has the same D as the tensor of indices (the 3rd argument with torch or the 2nd argument with a tensor).
import torch

my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])

torch.gather(my_tensor, 0, torch.tensor([[0], [1], [0], [1]]))
my_tensor.gather(0, torch.tensor([[0], [1], [0], [1]]))
torch.gather(my_tensor, -2, torch.tensor([[0], [1], [0], [1]]))
# tensor([[9], [7], [9], [7]])

torch.gather(my_tensor, 0, torch.tensor([[0, 0, 0, 0, 0], 
                                         [1, 1, 1, 1, 1], 
                                         [0, 1, 0, 1, 0], 
                                         [1, 0, 1, 0, 1]]))
torch.gather(my_tensor, -2, torch.tensor([[0, 0, 0, 0, 0], 
                                          [1, 1, 1, 1, 1], 
                                          [0, 1, 0, 1, 0], 
                                          [1, 0, 1, 0, 1]]))
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 1, 0, 4, 2],
#         [7, 5, 3, 6, 8]])

torch.gather(my_tensor, 1, torch.tensor([[3, 0, 4]]))
torch.gather(my_tensor, -1, torch.tensor([[3, 0, 4]]))
# tensor([[6, 9, 2]])

torch.gather(my_tensor, 1,
             torch.tensor([[3, 0, 4], [4, 1, 2]]))
torch.gather(my_tensor, -1,
             torch.tensor([[3, 0, 4], [4, 1, 2]]))
# tensor([[6, 9, 2], [8, 1, 3]])
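
To make the indexing rule behind the gather() examples above concrete: for a 2D input and dim=1, each output element is out[i][j] = input[i][index[i][j]]. The following sketch checks this with an explicit loop and contrasts it with take_along_dim(); the names `index`, `gathered` and `manual` are only illustrative:

```python
import torch

my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])
index = torch.tensor([[3, 0, 4]])

# For a 2D input and dim=1: out[i][j] = input[i][index[i][j]].
gathered = torch.gather(my_tensor, 1, index)
manual = torch.empty_like(index)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        manual[i, j] = my_tensor[i, index[i, j]]

print(gathered)                       # tensor([[6, 9, 2]])
print(torch.equal(gathered, manual))  # True

# take_along_dim() broadcasts the index over the rows of the input,
# while gather() keeps the size of the index.
print(torch.take_along_dim(my_tensor, index, 1))
# tensor([[6, 9, 2],
#         [4, 7, 8]])
```

This is why the same index gives a 1x3 result with gather() but a 2x3 result with take_along_dim().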
