DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

chunk() in PyTorch

*Memos:

chunk() splits a tensor with one or more dimensions into one or more tensors, as shown below:

*Memos:

  • chunk() can be used with torch or a tensor.
  • The 1st argument (a tensor of int, float, complex or bool) with torch, or the tensor itself when calling the method on a tensor, is input (Required).
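  • The 2nd argument (int) with torch, or the 1st argument (int) with a tensor, is chunks (Required). Each chunk holds ceil(size / chunks) elements along dim, with the last one possibly smaller, so fewer than chunks tensors may be returned. For example, chunks=3 on 4 elements returns 2 tensors.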
  • The 3rd argument (int) with torch, or the 2nd argument (int) with a tensor, is dim (Optional, Default: 0), the dimension along which the input is split.
  • The total number of elements across the returned tensors is the same as in the input tensor.
  • The returned tensors keep the number of dimensions of the input tensor.
import torch

def assert_chunks(results, expected):
    """Check that every chunk() result in *results* matches *expected*.

    Args:
        results: An iterable of tuples, each one returned by a torch.chunk()
            or Tensor.chunk() call that the examples below treat as
            equivalent.
        expected: The tuple of tensors all of those calls should return.

    Raises:
        AssertionError: If a result has a different number of chunks, or a
            chunk whose dtype, shape or values differ from *expected*.
    """
    for result in results:
        assert len(result) == len(expected), (len(result), len(expected))
        for got, want in zip(result, expected):
            # torch.equal() compares shape and values; check dtype separately
            # so a chunk that changed type would also be caught.
            assert got.dtype == want.dtype, (got.dtype, want.dtype)
            assert torch.equal(got, want), (got, want)


# Each assert_chunks() call below groups equivalent chunk() calls and checks
# that they all return the output shown in the comment after it.

# 1D tensor with 4 elements.
my_tensor = torch.tensor([0, 1, 2, 3])

assert_chunks(
    (torch.chunk(input=my_tensor, chunks=1),
     my_tensor.chunk(chunks=1),
     torch.chunk(input=my_tensor, chunks=1, dim=0),
     torch.chunk(input=my_tensor, chunks=1, dim=-1)),
    (torch.tensor([0, 1, 2, 3]),),
)
# (tensor([0, 1, 2, 3]),)

# chunks=3 also returns only 2 tensors: each chunk holds ceil(4 / 3) = 2
# elements, so the 4 elements are used up after two chunks.
assert_chunks(
    (torch.chunk(input=my_tensor, chunks=2),
     torch.chunk(input=my_tensor, chunks=2, dim=0),
     torch.chunk(input=my_tensor, chunks=2, dim=-1),
     torch.chunk(input=my_tensor, chunks=3),
     torch.chunk(input=my_tensor, chunks=3, dim=0),
     torch.chunk(input=my_tensor, chunks=3, dim=-1)),
    (torch.tensor([0, 1]),
     torch.tensor([2, 3])),
)
# (tensor([0, 1]),
#  tensor([2, 3]))

assert_chunks(
    (torch.chunk(input=my_tensor, chunks=4),
     torch.chunk(input=my_tensor, chunks=4, dim=0),
     torch.chunk(input=my_tensor, chunks=4, dim=-1)),
    (torch.tensor([0]), torch.tensor([1]), torch.tensor([2]),
     torch.tensor([3])),
)
# (tensor([0]), tensor([1]), tensor([2]), tensor([3]))

# 2D tensor with 3 rows and 4 columns.
my_tensor = torch.tensor([[0, 1, 2, 3],
                          [4, 5, 6, 7],
                          [8, 9, 10, 11]])

assert_chunks(
    (torch.chunk(input=my_tensor, chunks=1),
     torch.chunk(input=my_tensor, chunks=1, dim=0),
     torch.chunk(input=my_tensor, chunks=1, dim=1),
     torch.chunk(input=my_tensor, chunks=1, dim=-1),
     torch.chunk(input=my_tensor, chunks=1, dim=-2)),
    (torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),),
)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)

assert_chunks(
    (torch.chunk(input=my_tensor, chunks=2),
     torch.chunk(input=my_tensor, chunks=2, dim=0),
     torch.chunk(input=my_tensor, chunks=2, dim=-2)),
    (torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
     torch.tensor([[8, 9, 10, 11]])),
)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

assert_chunks(
    (torch.chunk(input=my_tensor, chunks=2, dim=1),
     torch.chunk(input=my_tensor, chunks=2, dim=-1),
     torch.chunk(input=my_tensor, chunks=3, dim=1),
     torch.chunk(input=my_tensor, chunks=3, dim=-1)),
    (torch.tensor([[0, 1], [4, 5], [8, 9]]),
     torch.tensor([[2, 3], [6, 7], [10, 11]])),
)
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]))

assert_chunks(
    (torch.chunk(input=my_tensor, chunks=3),
     torch.chunk(input=my_tensor, chunks=3, dim=0),
     torch.chunk(input=my_tensor, chunks=3, dim=-2),
     torch.chunk(input=my_tensor, chunks=4),
     torch.chunk(input=my_tensor, chunks=4, dim=0),
     torch.chunk(input=my_tensor, chunks=4, dim=-2)),
    (torch.tensor([[0, 1, 2, 3]]),
     torch.tensor([[4, 5, 6, 7]]),
     torch.tensor([[8, 9, 10, 11]])),
)
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

assert_chunks(
    (torch.chunk(input=my_tensor, chunks=4, dim=1),
     torch.chunk(input=my_tensor, chunks=4, dim=-1)),
    (torch.tensor([[0], [4], [8]]),
     torch.tensor([[1], [5], [9]]),
     torch.tensor([[2], [6], [10]]),
     torch.tensor([[3], [7], [11]])),
)
# (tensor([[0], [4], [8]]),
#  tensor([[1], [5], [9]]),
#  tensor([[2], [6], [10]]),
#  tensor([[3], [7], [11]]))

# 3D tensor of shape (1, 3, 4).
my_tensor = torch.tensor([[[0, 1, 2, 3],
                           [4, 5, 6, 7],
                           [8, 9, 10, 11]]])

# dim 0 has size 1, so splitting along it always returns one tensor.
assert_chunks(
    (torch.chunk(input=my_tensor, chunks=1),
     torch.chunk(input=my_tensor, chunks=1, dim=0),
     torch.chunk(input=my_tensor, chunks=1, dim=1),
     torch.chunk(input=my_tensor, chunks=1, dim=2),
     torch.chunk(input=my_tensor, chunks=1, dim=-1),
     torch.chunk(input=my_tensor, chunks=1, dim=-2),
     torch.chunk(input=my_tensor, chunks=1, dim=-3),
     torch.chunk(input=my_tensor, chunks=2),
     torch.chunk(input=my_tensor, chunks=2, dim=0),
     torch.chunk(input=my_tensor, chunks=2, dim=-3),
     torch.chunk(input=my_tensor, chunks=3),
     torch.chunk(input=my_tensor, chunks=3, dim=0),
     torch.chunk(input=my_tensor, chunks=3, dim=-3),
     torch.chunk(input=my_tensor, chunks=4),
     torch.chunk(input=my_tensor, chunks=4, dim=0),
     torch.chunk(input=my_tensor, chunks=4, dim=-3)),
    (torch.tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]),),
)
# (tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]),)

assert_chunks(
    (torch.chunk(input=my_tensor, chunks=2, dim=1),
     torch.chunk(input=my_tensor, chunks=2, dim=-2)),
    (torch.tensor([[[0, 1, 2, 3], [4, 5, 6, 7]]]),
     torch.tensor([[[8, 9, 10, 11]]])),
)
# (tensor([[[0, 1, 2, 3], [4, 5, 6, 7]]]),
#  tensor([[[8, 9, 10, 11]]]))

assert_chunks(
    (torch.chunk(input=my_tensor, chunks=2, dim=2),
     torch.chunk(input=my_tensor, chunks=2, dim=-1),
     torch.chunk(input=my_tensor, chunks=3, dim=2),
     torch.chunk(input=my_tensor, chunks=3, dim=-1)),
    (torch.tensor([[[0, 1], [4, 5], [8, 9]]]),
     torch.tensor([[[2, 3], [6, 7], [10, 11]]])),
)
# (tensor([[[0, 1], [4, 5], [8, 9]]]),
#  tensor([[[2, 3], [6, 7], [10, 11]]]))

assert_chunks(
    (torch.chunk(input=my_tensor, chunks=3, dim=1),
     torch.chunk(input=my_tensor, chunks=3, dim=-2),
     torch.chunk(input=my_tensor, chunks=4, dim=1),
     torch.chunk(input=my_tensor, chunks=4, dim=-2)),
    (torch.tensor([[[0, 1, 2, 3]]]),
     torch.tensor([[[4, 5, 6, 7]]]),
     torch.tensor([[[8, 9, 10, 11]]])),
)
# (tensor([[[0, 1, 2, 3]]]),
#  tensor([[[4, 5, 6, 7]]]),
#  tensor([[[8, 9, 10, 11]]]))

assert_chunks(
    (torch.chunk(input=my_tensor, chunks=4, dim=2),
     torch.chunk(input=my_tensor, chunks=4, dim=-1)),
    (torch.tensor([[[0], [4], [8]]]),
     torch.tensor([[[1], [5], [9]]]),
     torch.tensor([[[2], [6], [10]]]),
     torch.tensor([[[3], [7], [11]]])),
)
# (tensor([[[0], [4], [8]]]),
#  tensor([[[1], [5], [9]]]),
#  tensor([[[2], [6], [10]]]),
#  tensor([[[3], [7], [11]]]))

# chunk() keeps the input's dtype: float, complex and bool work the same way.
my_tensor = torch.tensor([[[0., 1., 2., 3.],
                           [4., 5., 6., 7.],
                           [8., 9., 10., 11.]]])
assert_chunks(
    (torch.chunk(input=my_tensor, chunks=1),),
    (torch.tensor([[[0., 1., 2., 3.],
                    [4., 5., 6., 7.],
                    [8., 9., 10., 11.]]]),),
)
# (tensor([[[0., 1., 2., 3.],
#           [4., 5., 6., 7.],
#           [8., 9., 10., 11.]]]),)

my_tensor = torch.tensor([[[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
                           [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
                           [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]])
assert_chunks(
    (torch.chunk(input=my_tensor, chunks=1),),
    (torch.tensor([[[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
                    [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
                    [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]]),),
)
# (tensor([[[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
#           [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
#           [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]]),)

my_tensor = torch.tensor([[[True, False, True, False],
                           [False, True, False, True],
                           [True, False, True, False]]])
assert_chunks(
    (torch.chunk(input=my_tensor, chunks=1),),
    (torch.tensor([[[True, False, True, False],
                    [False, True, False, True],
                    [True, False, True, False]]]),),
)
# (tensor([[[True, False, True, False],
#           [False, True, False, True],
#           [True, False, True, False]]]),)
Enter fullscreen mode Exit fullscreen mode

Top comments (0)