*Memos:
- My post explains select().
- My post explains masked_select().
index_select() can get the 0D or more D tensor of zero or more elements selected with zero or more indices from the 0D or more D tensor of zero or more elements, without removing a dimension the way select() does, as shown below:
*Memos:
- index_select() can be used with torch or a tensor.
- The 1st argument(input) with torch or using a tensor(Required-Type:tensor of int, float, complex or bool). *It must be the 0D or more D tensor of zero or more elements.
- The 2nd argument with torch or the 1st argument with a tensor is dim(Required-Type:int).
- The 3rd argument with torch or the 2nd argument with a tensor is index(Required-Type:tensor of int). *It must be the 0D or 1D tensor of zero or more integers.
- There is out argument with torch(Optional-Default:None-Type:tensor):
  *Memos:
  - out= must be used.
  - My post explains out argument.
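Here is a minimal sketch of the out argument. The names index and result are only for illustration. A preallocated tensor with the input's dtype and one slot per index can receive the result:

```python
import torch

my_tensor = torch.tensor([8, -3, 0, 1, 5, -2, -1, 4])
index = torch.tensor([5, 2, 0, 7])

# Preallocate an output tensor with the same dtype as the input and
# one slot per index; index_select() writes its result into it.
result = torch.empty(4, dtype=my_tensor.dtype)
torch.index_select(input=my_tensor, dim=0, index=index, out=result)

print(result.tolist())  # [-2, 0, 8, 4]
```

Only torch.index_select() takes out. The tensor method my_tensor.index_select() does not.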
import torch
my_tensor = torch.tensor([8, -3, 0, 1, 5, -2, -1, 4])
torch.index_select(input=my_tensor, dim=0, index=torch.tensor(4))
my_tensor.index_select(dim=0, index=torch.tensor(4))
torch.index_select(input=my_tensor, dim=-1, index=torch.tensor(4))
# tensor([5])
torch.index_select(input=my_tensor, dim=0, index=torch.tensor([5, 2, 0, 7]))
torch.index_select(input=my_tensor, dim=-1, index=torch.tensor([5, 2, 0, 7]))
# tensor([-2, 0, 8, 4])
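For a 1D tensor, index_select() with dim=0 should pick the same elements as indexing with an integer tensor. An empty index gives the "zero indices" case. This small sketch gives the empty index the int64 dtype explicitly, because torch.tensor([]) alone is float32:

```python
import torch

my_tensor = torch.tensor([8, -3, 0, 1, 5, -2, -1, 4])
index = torch.tensor([5, 2, 0, 7])

# Along dim=0 of a 1D tensor, index_select() picks the same elements
# as indexing with an integer tensor.
picked = torch.index_select(input=my_tensor, dim=0, index=index)
print(torch.equal(picked, my_tensor[index]))  # True

# Zero indices: an empty integer index gives an empty result.
empty_index = torch.tensor([], dtype=torch.int64)
empty = torch.index_select(input=my_tensor, dim=0, index=empty_index)
print(empty.shape)  # torch.Size([0])
print(empty.dtype)  # torch.int64
```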
my_tensor = torch.tensor([[8, -3, 0, 1],
                          [5, -2, -1, 4]])
torch.index_select(input=my_tensor, dim=0, index=torch.tensor(1))
torch.index_select(input=my_tensor, dim=0, index=torch.tensor([1]))
torch.index_select(input=my_tensor, dim=-2, index=torch.tensor(1))
torch.index_select(input=my_tensor, dim=-2, index=torch.tensor([1]))
# tensor([[5, -2, -1, 4]])
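This sketch shows the "not removing one dimension" part by comparing select() with index_select() on the same 2D tensor. The variable names are only illustrative:

```python
import torch

my_tensor = torch.tensor([[8, -3, 0, 1],
                          [5, -2, -1, 4]])

# select() picks one index and removes dim 0 from the result.
selected = torch.select(my_tensor, dim=0, index=1)
# index_select() keeps dim 0; its size becomes the number of indices (1).
index_selected = torch.index_select(my_tensor, dim=0, index=torch.tensor([1]))

print(selected.shape)        # torch.Size([4])
print(index_selected.shape)  # torch.Size([1, 4])
# Squeezing out the kept dimension recovers the select() result.
print(torch.equal(index_selected.squeeze(0), selected))  # True
```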
torch.index_select(input=my_tensor, dim=0, index=torch.tensor([1, 0, 0, 1]))
torch.index_select(input=my_tensor, dim=-2, index=torch.tensor([1, 0, 0, 1]))
# tensor([[5, -2, -1, 4],
# [8, -3, 0, 1],
# [8, -3, 0, 1],
# [5, -2, -1, 4]])
torch.index_select(input=my_tensor, dim=1, index=torch.tensor([3, 1, 2]))
torch.index_select(input=my_tensor, dim=-1, index=torch.tensor([3, 1, 2]))
# tensor([[1, -3, 0],
# [4, -2, -1]])
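index_select() returns a new tensor rather than a view of input; the PyTorch docs note that the result does not use the same storage. Writing to the result should therefore leave the original alone. This quick check also compares the result against list indexing along dim=1:

```python
import torch

my_tensor = torch.tensor([[8, -3, 0, 1],
                          [5, -2, -1, 4]])

columns = torch.index_select(input=my_tensor, dim=1, index=torch.tensor([3, 1, 2]))
# Same columns as indexing dim 1 with a list of integers.
print(torch.equal(columns, my_tensor[:, [3, 1, 2]]))  # True

# The result is a copy, so writing to it does not change my_tensor.
columns[0, 0] = 100
print(my_tensor[0, 3].item())  # 1
```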
my_tensor = torch.tensor([[[8, -3], [0, 1]],
                          [[5, -2], [-1, 4]]])
torch.index_select(input=my_tensor, dim=2, index=torch.tensor(1))
torch.index_select(input=my_tensor, dim=2, index=torch.tensor([1]))
torch.index_select(input=my_tensor, dim=-1, index=torch.tensor(1))
torch.index_select(input=my_tensor, dim=-1, index=torch.tensor([1]))
# tensor([[[-3], [1]],
# [[-2], [4]]])
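The shapes above appear to follow one rule. The output keeps input's shape, except that the size at dim becomes the number of indices, and a 0D index counts as one. expected_shape() below is a hypothetical helper written only to check that rule against a few cases:

```python
import torch

def expected_shape(shape, dim, index):
    # Hypothetical helper: same shape as the input, except the size at
    # `dim` becomes the number of indices (a 0D index counts as 1).
    shape = list(shape)
    shape[dim] = index.numel()
    return torch.Size(shape)

my_tensor = torch.tensor([[[8, -3], [0, 1]],
                          [[5, -2], [-1, 4]]])

for dim, index in [(2, torch.tensor(1)),
                   (-1, torch.tensor([1, 1, 0])),
                   (0, torch.tensor([], dtype=torch.int64))]:
    result = torch.index_select(input=my_tensor, dim=dim, index=index)
    print(dim, tuple(result.shape), result.shape == expected_shape(my_tensor.shape, dim, index))
# 2 (2, 2, 1) True
# -1 (2, 2, 3) True
# 0 (0, 2, 2) True
```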
my_tensor = torch.tensor([[[8., -3.], [0., 1.]],
                          [[5., -2.], [-1., 4.]]])
torch.index_select(input=my_tensor, dim=2, index=torch.tensor(1))
# tensor([[[-3.], [1.]],
# [[-2.], [4.]]])
my_tensor = torch.tensor([[[8.+0.j, -3.+0.j], [0.+0.j, 1.+0.j]],
                          [[5.+0.j, -2.+0.j], [-1.+0.j, 4.+0.j]]])
torch.index_select(input=my_tensor, dim=2, index=torch.tensor(1))
# tensor([[[-3.+0.j], [1.+0.j]],
# [[-2.+0.j], [4.+0.j]]])
my_tensor = torch.tensor([[[True, False], [True, False]],
                          [[False, True], [False, True]]])
torch.index_select(input=my_tensor, dim=2, index=torch.tensor(1))
# tensor([[[False], [False]],
# [[True], [True]]])
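The dtype of input is kept as is, and index may be int32 as well as int64. This sketch also tries a few invalid index tensors. The exact exception type and message may vary between PyTorch versions, so it only catches RuntimeError and IndexError and prints which one was raised:

```python
import torch

my_tensor = torch.tensor([[[True, False], [True, False]],
                          [[False, True], [False, True]]])

# int32 indices are accepted as well as the default int64.
ok = torch.index_select(input=my_tensor, dim=2, index=torch.tensor([1], dtype=torch.int32))
print(ok.dtype)  # torch.bool (the input's dtype is kept)

# Each of these index tensors is invalid for index_select().
bad_indices = {
    "float index": torch.tensor([1.0]),
    "2D index": torch.tensor([[1]]),
    "out of range": torch.tensor([2]),
}
for name, index in bad_indices.items():
    try:
        torch.index_select(input=my_tensor, dim=2, index=index)
        print(name, "-> accepted")
    except (RuntimeError, IndexError) as e:
        print(name, "->", type(e).__name__)
```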