take() can take one or more elements from a 0D or more D tensor using a 0D or more D tensor as shown below:
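*Memos:
- `take()` can be called both from `torch` and from a tensor.
- With `torch`, the 2nd argument (Required) is a 0D or more D tensor of indices.
- With a tensor, the 1st argument (Required) is a 0D or more D tensor of indices.
- The indices address the input tensor as if it were flattened to 1D (so index 7 of a 2×5 tensor is the 3rd element of the 2nd row), and negative indices count from the end.
- The index tensor (the 2nd argument with `torch` or the 1st argument with a tensor) decides the size of the returned tensor.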
import torch

# Source tensor. take() ignores its shape and indexes it as the flat
# sequence [9, 5, 0, 6, 2, 7, 1, 3, 4, 8]; negative indices count back
# from the end (-1 -> 8, -10 -> 9).
my_tensor = torch.tensor([[9, 5, 0, 6, 2],
                          [7, 1, 3, 4, 8]])

# 0D index -> 0D result (positions 3 and -7 name the same element).
torch.take(input=my_tensor, index=torch.tensor(3))
my_tensor.take(index=torch.tensor(3))
torch.take(input=my_tensor, index=torch.tensor(-7))
# tensor(6)

# 1D index -> 1D result.
torch.take(input=my_tensor, index=torch.tensor([3, 0, 7, 4]))
torch.take(input=my_tensor, index=torch.tensor([-7, -10, -3, -6]))
# tensor([6, 9, 3, 2])

# 2D index -> 2D result: the output always has the index tensor's shape.
torch.take(input=my_tensor, index=torch.tensor([[3, 0], [7, 4]]))
torch.take(input=my_tensor, index=torch.tensor([[-7, -10], [-3, -6]]))
# tensor([[6, 9], [3, 2]])

# 3D index -> 3D result.
torch.take(input=my_tensor,
           index=torch.tensor([[[3, 0], [7, 4]],
                               [[8, 2], [3, 5]]]))
torch.take(input=my_tensor,
           index=torch.tensor([[[-7, -10], [-3, -6]],
                               [[-2, -8], [-7, -5]]]))
# tensor([[[6, 9], [3, 2]], [[4, 0], [6, 7]]])
take_along_dim() can take one or more elements from a 0D or more D tensor using a 0D or more D tensor as shown below:
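*Memos:
- `take_along_dim()` can be called both from `torch` and from a tensor.
- With `torch`, the 2nd argument (Required) is a 0D or more D tensor of indices.
- With a tensor, the 1st argument (Required) is a 0D or more D tensor of indices.
- With `torch`, the 3rd argument (Optional) is a dimension.
- With a tensor, the 2nd argument (Optional) is a dimension.
- If a dimension is not set, the input tensor is treated as if it were flattened to 1D, and the returned tensor is always 1D.
- If a dimension is set, both tensors must have the same number of dimensions, and the returned tensor has that number of dimensions too.
- If a dimension is set, the index tensor is broadcast against the input tensor in every other dimension, so an index tensor of size (4, 1) used with dimension 0 on a (2, 5) tensor gives a (4, 5) result.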
import torch

my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])

# Without a dimension, take_along_dim() reads my_tensor as the flat
# sequence [9, 5, 0, 6, 2, 7, 1, 3, 4, 8] and always returns a 1D tensor,
# even for a 0D index. The gather() line performs the same lookup
# explicitly on the flattened tensor; gather() needs a dim plus an index
# tensor of the same rank as its input, hence the 1D index [3].
torch.take_along_dim(my_tensor, torch.tensor(3))
my_tensor.take_along_dim(torch.tensor(3))
torch.gather(my_tensor.flatten(), 0, torch.tensor([3]))
# tensor([6])

# Index tensors of any rank are flattened too, so the result stays 1D.
torch.take_along_dim(my_tensor, torch.tensor([3, 0, 7, 4]))
torch.take_along_dim(my_tensor, torch.tensor([[3, 0], [7, 4]]))
# tensor([6, 9, 3, 2])
torch.take_along_dim(my_tensor, torch.tensor([[[3, 0], [7, 4]],
                                              [[8, 2], [3, 5]]]))
# tensor([6, 9, 3, 2, 4, 0, 6, 7])

# With a dimension, the index tensor is broadcast against my_tensor in the
# other dimensions: the (4, 1) index below is expanded across all 5 columns.
torch.take_along_dim(my_tensor,
                     torch.tensor([[0], [1], [0], [1]]), 0)
torch.take_along_dim(my_tensor,
                     torch.tensor([[0], [1], [0], [1]]), -2)
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8]])

# A full-width index selects a row per column: out[i][j] = my_tensor[index[i][j]][j].
torch.take_along_dim(my_tensor, torch.tensor([[0, 0, 0, 0, 0],
                                              [1, 1, 1, 1, 1],
                                              [0, 1, 0, 1, 0],
                                              [1, 0, 1, 0, 1]]), 0)
torch.take_along_dim(my_tensor, torch.tensor([[0, 0, 0, 0, 0],
                                              [1, 1, 1, 1, 1],
                                              [0, 1, 0, 1, 0],
                                              [1, 0, 1, 0, 1]]), -2)
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 1, 0, 4, 2],
#         [7, 5, 3, 6, 8]])

# Along dimension 1, a single index row is broadcast to both rows.
torch.take_along_dim(my_tensor, torch.tensor([[3, 0, 4]]), 1)
torch.take_along_dim(my_tensor, torch.tensor([[3, 0, 4]]), -1)
# tensor([[6, 9, 2], [4, 7, 8]])

# One index row per input row: out[i][j] = my_tensor[i][index[i][j]].
torch.take_along_dim(my_tensor,
                     torch.tensor([[3, 0, 4], [4, 1, 2]]), 1)
torch.take_along_dim(my_tensor,
                     torch.tensor([[3, 0, 4], [4, 1, 2]]), -1)
# tensor([[6, 9, 2], [8, 1, 3]])
gather() can take one or more elements from a 0D or more D tensor using a 0D or more D tensor as shown below:
*Memos:
- `gather()` can be called both from `torch` and from a tensor.
- With `torch`, the 2nd argument (Required) is a dimension.
- With a tensor, the 1st argument (Required) is a dimension.
- With `torch`, the 3rd argument (Required) is a 0D or more D tensor of indices.
- With a tensor, the 2nd argument (Required) is a 0D or more D tensor of indices.
- Basically, both tensors must have the same number of dimensions, and the returned tensor has that number of dimensions. A 0D tensor combined with a 1D tensor (in either order) is also allowed. In that case, the returned tensor has the same number of dimensions as the index tensor (the 3rd argument with `torch` or the 2nd argument with a tensor).
import torch

# gather() returns a tensor shaped exactly like the index tensor; each output
# element is read from my_tensor after replacing its coordinate along `dim`
# with the index value. Unlike take_along_dim(), nothing is broadcast.
my_tensor = torch.tensor([[9, 5, 0, 6, 2],
                          [7, 1, 3, 4, 8]])

# dim=0 (same as -2): out[i][j] = my_tensor[index[i][j]][j]
torch.gather(my_tensor, dim=0, index=torch.tensor([[0], [1], [0], [1]]))
my_tensor.gather(dim=0, index=torch.tensor([[0], [1], [0], [1]]))
torch.gather(my_tensor, dim=-2, index=torch.tensor([[0], [1], [0], [1]]))
# tensor([[9], [7], [9], [7]])

torch.gather(my_tensor, dim=0,
             index=torch.tensor([[0, 0, 0, 0, 0],
                                 [1, 1, 1, 1, 1],
                                 [0, 1, 0, 1, 0],
                                 [1, 0, 1, 0, 1]]))
torch.gather(my_tensor, dim=-2,
             index=torch.tensor([[0, 0, 0, 0, 0],
                                 [1, 1, 1, 1, 1],
                                 [0, 1, 0, 1, 0],
                                 [1, 0, 1, 0, 1]]))
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 1, 0, 4, 2],
#         [7, 5, 3, 6, 8]])

# dim=1 (same as -1): out[i][j] = my_tensor[i][index[i][j]]
torch.gather(my_tensor, dim=1, index=torch.tensor([[3, 0, 4]]))
torch.gather(my_tensor, dim=-1, index=torch.tensor([[3, 0, 4]]))
# tensor([[6, 9, 2]])

torch.gather(my_tensor, dim=1,
             index=torch.tensor([[3, 0, 4], [4, 1, 2]]))
torch.gather(my_tensor, dim=-1,
             index=torch.tensor([[3, 0, 4], [4, 1, 2]]))
# tensor([[6, 9, 2], [8, 1, 3]])
Top comments (0)