*My post explains take() and take_along_dim().
gather() can get a 0D or more D tensor of zero or more elements from a 0D or more D tensor of zero or more elements, using a 0D or more D tensor of zero or more indices, as shown below:
*Memos:
- gather() can be called with torch or a tensor.
- The 1st argument(input) with torch or using a tensor(Required-Type:tensor of int, float, complex or bool).
- The 2nd argument with torch or the 1st argument with a tensor is dim(Required-Type:int).
- The 3rd argument with torch or the 2nd argument with a tensor is index(Required-Type:tensor of int64).
- There is out argument with torch(Optional-Default:None-Type:tensor):
  *Memos:
  - out= must be used.
  - My post explains out argument.
- Basically, an input and index tensor must be the same D, and the returned tensor has the same shape as index. But the combination of a 0D(input) and 1D(index) tensor or a 1D(input) and 0D(index) tensor is also possible, and then the returned tensor is the same D as index.
- Along every dimension except dim, index must not be larger than input, but along dim it can be larger, e.g. a 5x2 index with a 3x4 input and dim=0, or a 2x6 index with a 3x4 input and dim=1.
import torch
# A 3x4 int64 input shared by every example in this block.
my_tensor = torch.tensor([[10, 11, 12, 13],
                          [14, 15, 16, 17],
                          [18, 19, 20, 21]])

# dim=0 (for this 2D input, the same axis as dim=-2):
#     out[i][j] = my_tensor[index[i][j]][j]
# Each index row repeats its own row number, so the input comes back
# unchanged. torch.gather(), Tensor.gather() and dim=-2 all agree.
identity_rows = torch.tensor([[0, 0, 0, 0],
                              [1, 1, 1, 1],
                              [2, 2, 2, 2]])
torch.gather(input=my_tensor, dim=0, index=identity_rows)
my_tensor.gather(dim=0, index=identity_rows)
torch.gather(input=my_tensor, dim=-2, index=identity_rows)
# tensor([[10, 11, 12, 13],
#         [14, 15, 16, 17],
#         [18, 19, 20, 21]])

# Different row numbers in each column shuffle the values within that column.
shuffled_rows = torch.tensor([[0, 2, 1, 0],
                              [1, 0, 2, 1],
                              [2, 1, 0, 2]])
torch.gather(input=my_tensor, dim=0, index=shuffled_rows)
# tensor([[10, 19, 16, 13],
#         [14, 11, 20, 17],
#         [18, 15, 12, 21]])

# The output always takes the index's shape: two index rows -> two output rows.
two_rows = torch.tensor([[0, 2, 1, 0],
                         [1, 0, 2, 1]])
torch.gather(input=my_tensor, dim=0, index=two_rows)
# tensor([[10, 19, 16, 13],
#         [14, 11, 20, 17]])

# Along dim the index may be longer than the input (5 rows > 3); along the
# other dimension it must not exceed the input (2 columns <= 4).
tall_rows = torch.tensor([[0, 2],
                          [1, 0],
                          [2, 1],
                          [0, 2],
                          [1, 0]])
torch.gather(input=my_tensor, dim=0, index=tall_rows)
# tensor([[10, 19],
#         [14, 11],
#         [18, 15],
#         [10, 19],
#         [14, 11]])

# dim=1 (for this 2D input, the same axis as dim=-1):
#     out[i][j] = my_tensor[i][index[i][j]]
# Each index row lists the columns in order, so the input comes back unchanged.
identity_cols = torch.tensor([[0, 1, 2, 3],
                              [0, 1, 2, 3],
                              [0, 1, 2, 3]])
torch.gather(input=my_tensor, dim=1, index=identity_cols)
torch.gather(input=my_tensor, dim=-1, index=identity_cols)
# tensor([[10, 11, 12, 13],
#         [14, 15, 16, 17],
#         [18, 19, 20, 21]])

# Rotating the column numbers in a row rotates that row's values.
rotated_cols = torch.tensor([[0, 1, 2, 3],
                             [3, 0, 1, 2],
                             [2, 3, 0, 1]])
torch.gather(input=my_tensor, dim=1, index=rotated_cols)
# tensor([[10, 11, 12, 13],
#         [17, 14, 15, 16],
#         [20, 21, 18, 19]])

# Fewer index columns than the input -> fewer output columns.
narrow_cols = torch.tensor([[0, 1],
                            [3, 0],
                            [2, 3]])
torch.gather(input=my_tensor, dim=1, index=narrow_cols)
# tensor([[10, 11],
#         [17, 14],
#         [20, 21]])

# More index columns than the input (6 > 4) is allowed along dim, and the
# index may have fewer rows than the input (2 <= 3).
wide_cols = torch.tensor([[0, 1, 2, 3, 0, 1],
                          [3, 0, 1, 2, 3, 0]])
torch.gather(input=my_tensor, dim=1, index=wide_cols)
# tensor([[10, 11, 12, 13, 10, 11],
#         [17, 14, 15, 16, 17, 14]])
# gather() also accepts a float input; the result keeps the input's float dtype.
my_tensor = torch.tensor([[10., 11., 12., 13.],
                          [14., 15., 16., 17.],
                          [18., 19., 20., 21.]])
identity_rows = torch.tensor([[0, 0, 0, 0],
                              [1, 1, 1, 1],
                              [2, 2, 2, 2]])
torch.gather(input=my_tensor, dim=0, index=identity_rows)
# tensor([[10., 11., 12., 13.],
#         [14., 15., 16., 17.],
#         [18., 19., 20., 21.]])
# gather() also accepts a complex input; the result keeps the complex dtype.
my_tensor = torch.tensor([[10.+0.j, 11.+0.j, 12.+0.j, 13.+0.j],
                          [14.+0.j, 15.+0.j, 16.+0.j, 17.+0.j],
                          [18.+0.j, 19.+0.j, 20.+0.j, 21.+0.j]])
identity_rows = torch.tensor([[0, 0, 0, 0],
                              [1, 1, 1, 1],
                              [2, 2, 2, 2]])
torch.gather(input=my_tensor, dim=0, index=identity_rows)
# tensor([[10.+0.j, 11.+0.j, 12.+0.j, 13.+0.j],
#         [14.+0.j, 15.+0.j, 16.+0.j, 17.+0.j],
#         [18.+0.j, 19.+0.j, 20.+0.j, 21.+0.j]])
# gather() also accepts a bool input; the result is a bool tensor.
my_tensor = torch.tensor([[True, False, True, False],
                          [False, True, False, True],
                          [True, False, True, False]])
identity_rows = torch.tensor([[0, 0, 0, 0],
                              [1, 1, 1, 1],
                              [2, 2, 2, 2]])
torch.gather(input=my_tensor, dim=0, index=identity_rows)
# tensor([[True, False, True, False],
#         [False, True, False, True],
#         [True, False, True, False]])
# Special case: a 1D input with a 0D index is accepted, and the result is 0D
# (the same shape as the index).
my_tensor = torch.tensor([10, 11, 12, 13])
scalar_index = torch.tensor(2)
torch.gather(input=my_tensor, dim=0, index=scalar_index)
# tensor(12)
# Special case: a 0D input with a 1D index is accepted too, and the result is
# 1D (the same shape as the index).
my_tensor = torch.tensor(2)
repeated_index = torch.tensor([0, 0, 0, 0])
torch.gather(input=my_tensor, dim=0, index=repeated_index)
# tensor([2, 2, 2, 2])
Top comments (0)