DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

kthvalue() and topk() in PyTorch

*Memos:

kthvalue() can get the kth smallest element(s) and their indices along one dimension of a 0D or higher-dimensional tensor, as shown below:

*Memos:

  • kthvalue() can be used with torch or a tensor.
  • The 1st argument(tensor of int, float, complex or bool) with torch, or the tensor itself when called as a method, is input(Required). *complex or bool can only be used for a 0D tensor.
  • The 2nd argument(int) with torch or the 1st argument(int) with a tensor is k(Required).
  • The 3rd argument(int) with torch or the 2nd argument(int) with a tensor is dim(Optional). *If dim is not given, the last dimension of the input is used.
  • The 4th argument(bool) with torch or the 3rd argument(bool) with a tensor is keepdim(Optional-Default:False) which keeps the dimension of the input tensor.
  • If multiple elements share the kth value, the index of any one of them may be returned nondeterministically.
import torch

# --- 0D tensors -------------------------------------------------------------
# A 0D tensor holds a single element, so k=1 selects it, and dim=0 and dim=-1
# give the same result. The torch.* function and the Tensor method are
# interchangeable, and arguments may be passed positionally or by keyword.
my_tensor = torch.tensor(5)

my_tensor.kthvalue(1)
torch.kthvalue(my_tensor, 1)
my_tensor.kthvalue(k=1, dim=0)
torch.kthvalue(my_tensor, 1, -1)
my_tensor.kthvalue(1, 0, True)
# Every call above returns:
# torch.return_types.kthvalue(
# values=tensor(5),
# indices=tensor(0))

my_tensor = torch.tensor(5.)

my_tensor.kthvalue(1)
# torch.return_types.kthvalue(
# values=tensor(5.),
# indices=tensor(0))

my_tensor = torch.tensor(5.+0.j)

my_tensor.kthvalue(1)
# torch.return_types.kthvalue(
# values=tensor(5.+0.j),
# indices=tensor(0))

my_tensor = torch.tensor(True)

my_tensor.kthvalue(1)
# torch.return_types.kthvalue(
# values=tensor(True),
# indices=tensor(0))

# --- 1D tensors -------------------------------------------------------------
# Sorted ascending the elements are 0, 1, 5, 5, 6, 7, 8, 9. The value 5 sits
# at both index 0 and index 7, so for k=3 and k=4 either index may be the one
# reported; the indices shown below are one observed result.
my_tensor = torch.tensor([5, 1, 9, 7, 6, 8, 0, 5])

my_tensor.kthvalue(3)
torch.kthvalue(my_tensor, 3, 0)
my_tensor.kthvalue(k=3, dim=-1)
# torch.return_types.kthvalue(
# values=tensor(5),
# indices=tensor(7))

# keepdim=True keeps the reduced dimension with size 1.
my_tensor.kthvalue(3, 0, True)
# torch.return_types.kthvalue(
# values=tensor([5]),
# indices=tensor([7]))

my_tensor.kthvalue(4)
torch.kthvalue(my_tensor, 4, 0)
my_tensor.kthvalue(k=4, dim=-1)
# torch.return_types.kthvalue(
# values=tensor(5),
# indices=tensor(0))

my_tensor.kthvalue(4, 0, True)
# torch.return_types.kthvalue(
# values=tensor([5]),
# indices=tensor([0]))

my_tensor = torch.tensor([5., 1., 9., 7., 6., 8., 0., 5.])

my_tensor.kthvalue(3)
# torch.return_types.kthvalue(
# values=tensor(5.),
# indices=tensor(7))

# --- 2D tensors -------------------------------------------------------------
# Without dim the last dimension (dim=1 / dim=-1) is used, so the 3rd
# smallest value is taken independently in each row.
my_tensor = torch.tensor([[5, 1, 9, 7],
                          [6, 8, 0, 5]])
my_tensor.kthvalue(3)
torch.kthvalue(my_tensor, 3, 1)
my_tensor.kthvalue(k=3, dim=-1)
# torch.return_types.kthvalue(
# values=tensor([7, 6]),
# indices=tensor([3, 0]))

my_tensor.kthvalue(3, 1, True)
# torch.return_types.kthvalue(
# values=tensor([[7], [6]]),
# indices=tensor([[3], [0]]))
Enter fullscreen mode Exit fullscreen mode

topk() can get the k (zero or more) largest or smallest elements and their indices along one dimension of a 0D or higher-dimensional tensor, as shown below:

*Memos:

  • topk() can be used with torch or a tensor.
  • The 1st argument(tensor of int, float, complex or bool) with torch, or the tensor itself when called as a method, is input(Required). *complex or bool can only be used for a 0D tensor on CPU.
  • The 2nd argument(int) with torch or the 1st argument(int) with a tensor is k(Required).
  • The 3rd argument(int) with torch or the 2nd argument(int) with a tensor is dim(Optional). *If dim is not given, the last dimension of the input is used.
  • The 4th argument(bool) with torch or the 3rd argument(bool) with a tensor is largest(Optional-Default:True). *True gets zero or more largest elements while False gets zero or more smallest elements.
  • The 5th argument(bool) with torch or the 4th argument(bool) with a tensor is sorted(Optional-Default:True). *With False, the returned elements may or may not happen to be sorted, so set it to True if you need a guaranteed sorted result.
  • If multiple elements share the value at the kth position, which of them (and therefore which indices) are returned is nondeterministic.
import torch

# --- 0D tensors -------------------------------------------------------------
# A 0D tensor holds a single element, so k=1 returns it. dim=0 and dim=-1,
# as well as largest=False and sorted=False, all give the same result here.
my_tensor = torch.tensor(5)

torch.topk(input=my_tensor, k=1)
my_tensor.topk(k=1)
torch.topk(input=my_tensor, k=1, dim=0)
torch.topk(input=my_tensor, k=1, dim=-1)
torch.topk(input=my_tensor, k=1, dim=0, largest=False)
torch.topk(input=my_tensor, k=1, dim=0, largest=False, sorted=False)
# torch.return_types.topk(
# values=tensor(5),
# indices=tensor(0))

my_tensor = torch.tensor(5.)

torch.topk(input=my_tensor, k=1)
# torch.return_types.topk(
# values=tensor(5.),
# indices=tensor(0))

my_tensor = torch.tensor(5.+0.j)

torch.topk(input=my_tensor, k=1)
# torch.return_types.topk(
# values=tensor(5.+0.j),
# indices=tensor(0))

my_tensor = torch.tensor(True)

torch.topk(input=my_tensor, k=1)
# torch.return_types.topk(
# values=tensor(True),
# indices=tensor(0))

# --- 1D tensors -------------------------------------------------------------
# The value 5 appears at both index 0 and index 7. When tied values straddle
# the k boundary, or when sorted=False, the exact order/indices returned are
# not guaranteed; the outputs below are one observed result.
my_tensor = torch.tensor([5, 1, 9, 7, 6, 8, 0, 5])

torch.topk(input=my_tensor, k=3)
torch.topk(input=my_tensor, k=3, dim=0)
torch.topk(input=my_tensor, k=3, dim=-1)
# torch.return_types.topk(
# values=tensor([9, 8, 7]),
# indices=tensor([2, 5, 3]))

torch.topk(input=my_tensor, k=3, dim=0, largest=False)
# torch.return_types.topk(
# values=tensor([0, 1, 5]),
# indices=tensor([6, 1, 0]))

torch.topk(input=my_tensor, k=3, dim=0, largest=False, sorted=False)
# torch.return_types.topk(
# values=tensor([1, 0, 5]),
# indices=tensor([1, 6, 0]))

torch.topk(input=my_tensor, k=4)
torch.topk(input=my_tensor, k=4, dim=0)
torch.topk(input=my_tensor, k=4, dim=-1)
# torch.return_types.topk(
# values=tensor([9, 8, 7, 6]),
# indices=tensor([2, 5, 3, 4]))

torch.topk(input=my_tensor, k=4, dim=0, largest=False)
# torch.return_types.topk(
# values=tensor([0, 1, 5, 5]),
# indices=tensor([6, 1, 0, 7]))

torch.topk(input=my_tensor, k=4, dim=0, largest=False, sorted=False)
# torch.return_types.topk(
# values=tensor([1, 0, 5, 5]),
# indices=tensor([1, 6, 0, 7]))

my_tensor = torch.tensor([5., 1., 9., 7., 6., 8., 0., 5.])

torch.topk(input=my_tensor, k=3)
# torch.return_types.topk(
# values=tensor([9., 8., 7.]),
# indices=tensor([2, 5, 3]))

# --- 2D tensors -------------------------------------------------------------
# Without dim the last dimension (dim=1 / dim=-1) is used, so the k
# largest/smallest values are selected independently in each row.
my_tensor = torch.tensor([[5, 1, 9, 7],
                          [6, 8, 0, 5]])
torch.topk(input=my_tensor, k=3)
torch.topk(input=my_tensor, k=3, dim=1)
torch.topk(input=my_tensor, k=3, dim=-1)
# torch.return_types.topk(
# values=tensor([[9, 7, 5], [8, 6, 5]]),
# indices=tensor([[2, 3, 0], [1, 0, 3]]))

torch.topk(input=my_tensor, k=3, dim=1, largest=False)
# torch.return_types.topk(
# values=tensor([[1, 5, 7], [0, 5, 6]]),
# indices=tensor([[1, 0, 3], [2, 3, 0]]))

torch.topk(input=my_tensor, k=3, dim=1, largest=False, sorted=False)
# torch.return_types.topk(
# values=tensor([[1, 5, 7], [5, 0, 6]]),
# indices=tensor([[1, 0, 3], [3, 2, 0]]))
Enter fullscreen mode Exit fullscreen mode

Top comments (0)