DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

masked_select in PyTorch

Buy Me a Coffee

*Memos:

masked_select() can get a 1D tensor of the zero or more elements selected by a boolean mask from a 0D or more D tensor of zero or more elements, as shown below:

*Memos:

  • masked_select() can be used with torch or a tensor.
  • The 1st argument with torch or using a tensor is input(Required-Type:tensor of int, float, complex or bool).
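  • The 2nd argument with torch or the 1st argument with a tensor is mask(Required-Type:tensor of bool). *It must be a 0D or more D tensor of zero or more boolean values that is broadcastable with input, e.g. a 0D mask (torch.tensor(True) or torch.tensor(False)) applies to every element.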
  • There is out argument with torch(Optional-Default:None-Type:tensor): *Memos:
    • out= must be used.
    • My post explains out argument.
import torch

my_tensor = torch.tensor([8, -3, 0, 1, 5, -2])

# A 1D boolean mask with the same shape as the input keeps the elements
# at the True positions, preserving their order.
mask = torch.tensor([False, True, True, False, True, False])
torch.masked_select(input=my_tensor, mask=mask)  # function form
my_tensor.masked_select(mask=mask)               # method form
# tensor([-3, 0, 5])

# A 0D mask broadcasts over every element, so True keeps everything;
# an all-True mask of the full shape does the same thing.
torch.masked_select(input=my_tensor, mask=torch.tensor(True))
torch.masked_select(input=my_tensor, mask=torch.ones_like(mask))
# tensor([8, -3, 0, 1, 5, -2])

# False (0D or all-False) keeps nothing: the result is an empty 1D
# tensor that still carries the input's dtype.
torch.masked_select(input=my_tensor, mask=torch.tensor(False))
torch.masked_select(input=my_tensor, mask=torch.zeros_like(mask))
# tensor([], dtype=torch.int64)

# Shape (2, 3): two rows of three elements.
my_tensor = torch.tensor([[8, -3, 0],
                          [1, 5, -2]])

# The mask has the same (2, 3) shape. The selected values come back
# flattened into a single 1D tensor, row by row.
mask = torch.tensor([[False, True, True],
                     [False, True, False]])
torch.masked_select(input=my_tensor, mask=mask)
# tensor([-3, 0, 5])

# A 0D mask broadcasts to the whole 2D shape.
keep_all, keep_none = torch.tensor(True), torch.tensor(False)
torch.masked_select(input=my_tensor, mask=keep_all)
# tensor([8, -3, 0, 1, 5, -2])

torch.masked_select(input=my_tensor, mask=keep_none)
# tensor([], dtype=torch.int64)

# Shape (2, 3, 1). This is the same tensor as the nested literal
# [[[8], [-3], [0]], [[1], [5], [-2]]].
my_tensor = torch.tensor([8, -3, 0, 1, 5, -2]).reshape(2, 3, 1)

# The mask is built the same way, so its shape matches the input exactly.
mask = torch.tensor([False, True, True, False, True, False]).reshape(2, 3, 1)
torch.masked_select(input=my_tensor, mask=mask)
# tensor([-3, 0, 5])

# A 0D mask broadcasts across all three dimensions.
torch.masked_select(input=my_tensor, mask=torch.tensor(True))
# tensor([8, -3, 0, 1, 5, -2])

torch.masked_select(input=my_tensor, mask=torch.tensor(False))
# tensor([], dtype=torch.int64)

# A float input of shape (2, 3, 1). The result keeps the float dtype.
my_tensor = torch.tensor([8., -3., 0., 1., 5., -2.]).reshape(2, 3, 1)

mask = torch.tensor([False, True, True, False, True, False]).reshape(2, 3, 1)
torch.masked_select(input=my_tensor, mask=mask)
# tensor([-3., 0., 5.])

# A complex input of shape (2, 3, 1). The result keeps the complex dtype.
my_tensor = torch.tensor(
    [8.+0.j, -3.+0.j, 0.+0.j, 1.+0.j, 5.+0.j, -2.+0.j]
).reshape(2, 3, 1)

mask = torch.tensor([False, True, True, False, True, False]).reshape(2, 3, 1)
torch.masked_select(input=my_tensor, mask=mask)
# tensor([-3.+0.j, 0.+0.j, 5.+0.j])

# A bool input of shape (2, 3, 1). The selected values are themselves
# booleans, so the result is a bool tensor.
my_tensor = torch.tensor([True, False, True, False, True, False]).reshape(2, 3, 1)

mask = torch.tensor([False, True, True, False, True, False]).reshape(2, 3, 1)
torch.masked_select(input=my_tensor, mask=mask)
# tensor([False, True, True])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)