DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

where() and count_nonzero() in PyTorch

*My post explains argwhere() and nonzero().

where() can get a 0D or more D tensor whose zero or more elements are selected from either input or other (each a 0D or more D tensor or a scalar), depending on condition, as shown below:

*Memos:

  • where() can be used with torch or a tensor.
  • The 1st argument(tensor of bool) with torch or a tensor is condition(Required).
  • The 2nd argument(tensor or scalar of int, float, complex or bool) with torch, or the tensor itself(tensor of int, float, complex or bool) when using a tensor, is input(Required). *Memos:
    • With torch, when input is a scalar, pass condition and input positionally (without condition= and input=), as the examples below do; other= can still be given as a keyword.
    • With a tensor, input is the tensor itself, so it cannot be a scalar.
  • The 3rd argument(tensor or scalar of int, float, complex or bool) with torch or the 2nd argument(tensor or scalar of int, float, complex or bool) with a tensor is other(Required).
  • Where condition is True, the value of input is selected; otherwise, the value of other is selected. condition, input and other are broadcast to a common shape, which is the shape of the result.
# Examples of torch.where(condition, input, other) and Tensor.where(condition, other).
# Each output comment shows the result of the call(s) directly above it.
import torch

# Integer tensors: tensor1 has shape (2, 3) and tensor2 has shape (3,), so
# tensor2 is broadcast across both rows of tensor1.
tensor1 = torch.tensor([[5, 0, 4],
                        [0, 3, 1]])
tensor2 = torch.tensor([60, 70, 80])

# Keep tensor1's value where it is greater than 2, otherwise take tensor2's.
# The tensor-method form uses tensor1 itself as input.
torch.where(condition=tensor1 > 2, input=tensor1, other=tensor2)
tensor1.where(condition=tensor1 > 2, other=tensor2)
# tensor([[5, 70, 4],
#         [60, 3, 80]])

# Swapping input and other swaps which source fills each position.
torch.where(condition=tensor1 > 2, input=tensor2, other=tensor1)
# tensor([[60, 0, 80],
#         [0, 70, 1]])

# Scalar input (passed positionally): 10 wherever the condition holds.
torch.where(tensor1 > 2, 10, tensor2)
# tensor([[10, 70, 10],
#         [60, 10, 80]])

# Scalar other: 10 wherever the condition does not hold.
torch.where(condition=tensor1 > 2, input=tensor1, other=10)
# tensor([[5, 10, 4],
#         [10, 3, 10]])

# Both input and other are scalars; the result takes the condition's shape.
torch.where(tensor1 > 2, 10, 20)
# tensor([[10, 20, 10],
#         [20, 10, 20]])

# Float tensors with a 0D condition (True): the condition broadcasts to every
# position, so input is always selected.
tensor1 = torch.tensor([[5., 0., 4.],
                        [0., 3., 1.]])
tensor2 = torch.tensor([60., 70., 80.])
tensor3 = torch.tensor(True)
torch.where(condition=tensor3, input=tensor1, other=tensor2)
# tensor([[5., 0., 4.],
#         [0., 3., 1.]])

# Result takes tensor2's shape (3,) and is filled from the scalar input.
torch.where(tensor3, 5., other=tensor2)
# tensor([5., 5., 5.])

# The scalar other is never selected, so the result equals tensor1.
torch.where(condition=tensor3, input=tensor1, other=60.)
# tensor([[5., 0., 4.],
#         [0., 3., 1.]])

# All three arguments are 0D, so the result is 0D.
torch.where(tensor3, 5., other=60.)
# tensor(5.)

# Complex tensors with a 0D condition (False): other is always selected.
tensor1 = torch.tensor([[5.+0.j, 0.+0.j, 4.+0.j],
                        [0.+0.j, 3.+0.j, 1.+0.j]])
tensor2 = torch.tensor([60.+0.j, 70.+0.j, 80.+0.j])
tensor3 = torch.tensor(False)
# tensor2 (shape (3,)) is broadcast to tensor1's shape (2, 3).
torch.where(condition=tensor3, input=tensor1, other=tensor2)
# tensor([[60.+0.j, 70.+0.j, 80.+0.j],
#         [60.+0.j, 70.+0.j, 80.+0.j]])

torch.where(tensor3, 5.+0.j, other=tensor2)
# tensor([60.+0.j, 70.+0.j, 80.+0.j])

# The scalar other fills the whole (2, 3) result.
torch.where(condition=tensor3, input=tensor1, other=60.+0.j)
# tensor([[60.+0.j, 60.+0.j, 60.+0.j],
#         [60.+0.j, 60.+0.j, 60.+0.j]])

torch.where(tensor3, 5.+0.j, other=60.+0.j)
# tensor(60.+0.j)

# Bool tensors with a 0D condition (True): input is always selected.
tensor1 = torch.tensor([[True, False, True],
                        [False, True, False]])
tensor2 = torch.tensor([False, True, False])
tensor3 = torch.tensor(True)
torch.where(condition=tensor3, input=tensor1, other=tensor2)
# tensor([[True, False, True],
#         [False, True, False]])

torch.where(tensor3, True, other=tensor2)
# tensor([True, True, True])

torch.where(condition=tensor3, input=tensor1, other=False)
# tensor([[True, False, True],
#         [False, True, False]])

torch.where(tensor3, True, other=False)
# tensor(True)

# 3D integer tensor of shape (2, 2, 3): the same calls as the first group;
# tensor2 (shape (3,)) broadcasts over both leading dimensions.
tensor1 = torch.tensor([[[5, 0, 4], [0, 3, 1]],
                          [[0, 7, 0], [0, 6, 8]]])
tensor2 = torch.tensor([60, 70, 80])

torch.where(condition=tensor1 > 2, input=tensor1, other=tensor2)
# tensor([[[5, 70, 4],
#          [60, 3, 80]],
#         [[60, 7, 80],
#          [60, 6, 8]]])

torch.where(condition=tensor1 > 2, input=tensor2, other=tensor1)
# tensor([[[60, 0, 80],
#          [0, 70, 1]],
#         [[0, 70, 0],
#          [0, 70, 80]]])

torch.where(tensor1 > 2, 10, tensor2)
# tensor([[[10, 70, 10],
#          [60, 10, 80]],
#         [[60, 10, 80],
#          [60, 10, 10]]])

torch.where(condition=tensor1 > 2, input=tensor1, other=10)
# tensor([[[5, 10, 4],
#          [10, 3, 10]],
#         [[10, 7, 10],
#          [10, 6, 8]]])

torch.where(tensor1 > 2, 10, 20)
# tensor([[[10, 20, 10],
#          [20, 10, 20]],
#         [[20, 10, 20],
#          [20, 10, 10]]])
Enter fullscreen mode Exit fullscreen mode

count_nonzero() can count the non-zero elements of a 0D or more D tensor, either over the whole tensor or along the given dimension(s), as shown below:

*Memos:

  • count_nonzero() can be used with torch or a tensor.
  • The 1st argument(tensor of int, float, complex or bool) with torch, or the tensor itself(tensor of int, float, complex or bool) when using a tensor, is input(Required).
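  • The 2nd argument(int, tuple of int or list of int) with torch or the 1st argument(int, tuple of int or list of int) with a tensor is dim(Optional). If dim is omitted, all elements are counted and a 0D tensor is returned; otherwise, the given dimension(s) are reduced.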
# Examples of torch.count_nonzero() and Tensor.count_nonzero(). Every call in a
# group returns the value shown in the output comment that ends the group.
import torch

# 0D tensor: with no dim, or with dim 0 / -1 given as an int or a 1-tuple,
# the single non-zero value 5 is counted.
my_tensor = torch.tensor(5)

torch.count_nonzero(input=my_tensor)
my_tensor.count_nonzero()
torch.count_nonzero(input=my_tensor, dim=0)
torch.count_nonzero(input=my_tensor, dim=-1)
torch.count_nonzero(input=my_tensor, dim=(0,))
torch.count_nonzero(input=my_tensor, dim=(-1,))
# tensor(1)

# 1D int tensor: 5, 4, 3 and 1 are non-zero. Dim 0 (-1) is the only
# dimension, so every form counts all six elements.
my_tensor = torch.tensor([5, 0, 4, 0, 3, 1])

torch.count_nonzero(input=my_tensor)
torch.count_nonzero(input=my_tensor, dim=0)
torch.count_nonzero(input=my_tensor, dim=-1)
torch.count_nonzero(input=my_tensor, dim=(0,))
torch.count_nonzero(input=my_tensor, dim=(-1,))
# tensor(4)

# The same values as floats.
my_tensor = torch.tensor([5., 0., 4., 0., 3., 1.])

torch.count_nonzero(input=my_tensor)
torch.count_nonzero(input=my_tensor, dim=0)
torch.count_nonzero(input=my_tensor, dim=-1)
torch.count_nonzero(input=my_tensor, dim=(0,))
torch.count_nonzero(input=my_tensor, dim=(-1,))
# tensor(4)

# Complex tensor: the four elements other than 0.+0.j are counted.
my_tensor = torch.tensor([5.+0.j, 0.+0.j, 4.+0.j, 0.+0.j, 3.+0.j, 1.+0.j])

torch.count_nonzero(input=my_tensor)
# tensor(4)

# Bool tensor: each True is counted.
my_tensor = torch.tensor([True, False, True, False, True, False])

torch.count_nonzero(input=my_tensor)
# tensor(3)

# 2D tensor of shape (2, 3). Reducing over both dimensions, in either order
# and with positive or negative indices, counts all six elements.
my_tensor = torch.tensor([[5, 0, 4],
                          [0, 3, 1]])
torch.count_nonzero(input=my_tensor)
torch.count_nonzero(input=my_tensor, dim=(0, 1))
torch.count_nonzero(input=my_tensor, dim=(0, -1))
torch.count_nonzero(input=my_tensor, dim=(1, 0))
torch.count_nonzero(input=my_tensor, dim=(1, -2))
torch.count_nonzero(input=my_tensor, dim=(-1, 0))
torch.count_nonzero(input=my_tensor, dim=(-1, -2))
torch.count_nonzero(input=my_tensor, dim=(-2, 1))
torch.count_nonzero(input=my_tensor, dim=(-2, -1))
# tensor(4)

# Reduce dim 0 (-2): one count per column, result shape (3,).
torch.count_nonzero(input=my_tensor, dim=0)
torch.count_nonzero(input=my_tensor, dim=-2)
torch.count_nonzero(input=my_tensor, dim=(0,))
torch.count_nonzero(input=my_tensor, dim=(-2,))
# tensor([1, 1, 2])

# Reduce dim 1 (-1): one count per row, result shape (2,).
torch.count_nonzero(input=my_tensor, dim=1)
torch.count_nonzero(input=my_tensor, dim=-1)
torch.count_nonzero(input=my_tensor, dim=(1,))
torch.count_nonzero(input=my_tensor, dim=(-1,))
# tensor([2, 2])

# 3D tensor of shape (2, 2, 3) with 7 non-zero elements in total.
my_tensor = torch.tensor([[[5, 0, 4], [0, 3, 1]],
                          [[0, 7, 0], [0, 6, 8]]])
torch.count_nonzero(input=my_tensor)
# tensor(7)

# Reduce dim 0 (-3): result shape (2, 3).
torch.count_nonzero(input=my_tensor, dim=0)
torch.count_nonzero(input=my_tensor, dim=-3)
torch.count_nonzero(input=my_tensor, dim=(0,))
torch.count_nonzero(input=my_tensor, dim=(-3,))
# tensor([[1, 1, 1], [0, 2, 2]])

# Reduce dim 1 (-2): result shape (2, 3).
torch.count_nonzero(input=my_tensor, dim=1)
torch.count_nonzero(input=my_tensor, dim=-2)
torch.count_nonzero(input=my_tensor, dim=(1,))
torch.count_nonzero(input=my_tensor, dim=(-2,))
# tensor([[1, 1, 2], [0, 2, 1]])

# Reduce dim 2 (-1): result shape (2, 2).
torch.count_nonzero(input=my_tensor, dim=2)
torch.count_nonzero(input=my_tensor, dim=-1)
torch.count_nonzero(input=my_tensor, dim=(2,))
torch.count_nonzero(input=my_tensor, dim=(-1,))
# tensor([[2, 2], [1, 2]])

# Reduce dims 0 and 1 (any order, positive or negative indices):
# result shape (3,).
torch.count_nonzero(input=my_tensor, dim=(0, 1))
torch.count_nonzero(input=my_tensor, dim=(0, -2))
torch.count_nonzero(input=my_tensor, dim=(1, 0))
torch.count_nonzero(input=my_tensor, dim=(1, -3))
torch.count_nonzero(input=my_tensor, dim=(-2, 0))
torch.count_nonzero(input=my_tensor, dim=(-2, -3))
torch.count_nonzero(input=my_tensor, dim=(-3, 1))
torch.count_nonzero(input=my_tensor, dim=(-3, -2))
# tensor([1, 3, 3])

# Reduce dims 0 and 2: result shape (2,).
torch.count_nonzero(input=my_tensor, dim=(0, 2))
torch.count_nonzero(input=my_tensor, dim=(0, -1))
torch.count_nonzero(input=my_tensor, dim=(2, 0))
torch.count_nonzero(input=my_tensor, dim=(2, -3))
torch.count_nonzero(input=my_tensor, dim=(-1, 0))
torch.count_nonzero(input=my_tensor, dim=(-1, -3))
torch.count_nonzero(input=my_tensor, dim=(-3, 2))
torch.count_nonzero(input=my_tensor, dim=(-3, -1))
# tensor([3, 4])

# Reduce dims 1 and 2: one count per outermost slice, result shape (2,).
torch.count_nonzero(input=my_tensor, dim=(1, 2))
torch.count_nonzero(input=my_tensor, dim=(1, -1))
torch.count_nonzero(input=my_tensor, dim=(2, 1))
torch.count_nonzero(input=my_tensor, dim=(2, -2))
torch.count_nonzero(input=my_tensor, dim=(-1, 1))
torch.count_nonzero(input=my_tensor, dim=(-1, -2))
torch.count_nonzero(input=my_tensor, dim=(-2, 2))
torch.count_nonzero(input=my_tensor, dim=(-2, -1))
# tensor([4, 3])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)