*Memos:
- My post explains split().
- My post explains vsplit().
- My post explains hsplit().
- My post explains dsplit().
- My post explains tensor_split().
- My post explains unbind().
chunk() can split a 1D or more D tensor of zero or more elements into one or more 1D or more D view tensors of zero or more elements by specifying the number of chunks, as shown below:
*Memos:
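- `chunk()` can be used with `torch` or a tensor.
- The 1st argument with `torch` (`input`), or the tensor itself when calling the method on a tensor, is the tensor to split (Required-Type: `tensor` of `int`, `float`, `complex` or `bool`).
- The 2nd argument with `torch` or the 1st argument with a tensor is `chunks` (Required-Type: `int`).
- The 3rd argument with `torch` or the 2nd argument with a tensor is `dim` (Optional-Default: `0`-Type: `int`).
- The total number of the zero or more elements of one or more returned tensors doesn't change.
- One or more returned tensors keep the number of dimensions of the original tensor.
- `chunk()` may return fewer tensors than `chunks`. For example, `chunks=3` on a tensor of 4 elements returns 2 tensors of 2 elements each, as shown below.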
import torch
# NOTE: the "# (...)" comments show what an interactive session echoes for
# the group of calls directly above them. Run as a script, these bare
# expressions are evaluated and their results discarded.
# 1D tensor of 4 elements; for 1D input, dim=0 and dim=-1 are the same axis.
my_tensor = torch.tensor([0, 1, 2, 3])
# chunks=1 returns a 1-tuple holding the whole tensor. The method form
# my_tensor.chunk(...) matches torch.chunk(input=my_tensor, ...), and
# omitting dim is the same as dim=0.
torch.chunk(input=my_tensor, chunks=1)
my_tensor.chunk(chunks=1)
torch.chunk(input=my_tensor, chunks=1, dim=0)
torch.chunk(input=my_tensor, chunks=1, dim=-1)
# (tensor([0, 1, 2, 3]),)
# chunks=2 and chunks=3 give the same result. Asking for 3 chunks of 4
# elements still yields only 2 chunks, so len(result) can be < chunks.
torch.chunk(input=my_tensor, chunks=2)
torch.chunk(input=my_tensor, chunks=2, dim=0)
torch.chunk(input=my_tensor, chunks=2, dim=-1)
torch.chunk(input=my_tensor, chunks=3)
torch.chunk(input=my_tensor, chunks=3, dim=0)
torch.chunk(input=my_tensor, chunks=3, dim=-1)
# (tensor([0, 1]),
# tensor([2, 3]))
# chunks=4 equals the length, so each chunk holds a single element.
torch.chunk(input=my_tensor, chunks=4)
torch.chunk(input=my_tensor, chunks=4, dim=0)
torch.chunk(input=my_tensor, chunks=4, dim=-1)
# (tensor([0]), tensor([1]), tensor([2]), tensor([3]))
# 2D tensor of shape (3, 4): dim=0 (same as dim=-2) splits the 3 rows,
# dim=1 (same as dim=-1) splits the 4 columns. Each "# (...)" comment shows
# the result echoed for the group of calls directly above it.
my_tensor = torch.tensor([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]])
# chunks=1 along any dim returns the whole tensor as a single chunk.
torch.chunk(input=my_tensor, chunks=1)
torch.chunk(input=my_tensor, chunks=1, dim=0)
torch.chunk(input=my_tensor, chunks=1, dim=1)
torch.chunk(input=my_tensor, chunks=1, dim=-1)
torch.chunk(input=my_tensor, chunks=1, dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)
# chunks=2 along the rows: 3 rows split into chunks of 2 and 1
# (the last chunk is the smaller one).
torch.chunk(input=my_tensor, chunks=2)
torch.chunk(input=my_tensor, chunks=2, dim=0)
torch.chunk(input=my_tensor, chunks=2, dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))
# chunks=2 or chunks=3 along the columns: 4 columns split into 2 chunks
# of 2 columns each (chunks=3 still yields only 2 chunks).
torch.chunk(input=my_tensor, chunks=2, dim=1)
torch.chunk(input=my_tensor, chunks=2, dim=-1)
torch.chunk(input=my_tensor, chunks=3, dim=1)
torch.chunk(input=my_tensor, chunks=3, dim=-1)
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2, 3], [6, 7], [10, 11]]))
# chunks=3 or chunks=4 along the rows: one row per chunk, so only 3 chunks
# come back even when 4 are requested.
torch.chunk(input=my_tensor, chunks=3)
torch.chunk(input=my_tensor, chunks=3, dim=0)
torch.chunk(input=my_tensor, chunks=3, dim=-2)
torch.chunk(input=my_tensor, chunks=4)
torch.chunk(input=my_tensor, chunks=4, dim=0)
torch.chunk(input=my_tensor, chunks=4, dim=-2)
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))
# chunks=4 along the columns: one column per chunk; each chunk stays 2D.
torch.chunk(input=my_tensor, chunks=4, dim=1)
torch.chunk(input=my_tensor, chunks=4, dim=-1)
# (tensor([[0], [4], [8]]),
# tensor([[1], [5], [9]]),
# tensor([[2], [6], [10]]),
# tensor([[3], [7], [11]]))
# 3D tensor of shape (1, 3, 4): dim=0 (same as -3) has size 1, dim=1
# (same as -2) has size 3, dim=2 (same as -1) has size 4. Each "# (...)"
# comment shows the result echoed for the group of calls directly above it.
my_tensor = torch.tensor([[[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]]])
# chunks=1 along any dim, or any chunks value along the size-1 dim=0,
# returns a single chunk equal to the whole tensor.
torch.chunk(input=my_tensor, chunks=1)
torch.chunk(input=my_tensor, chunks=1, dim=0)
torch.chunk(input=my_tensor, chunks=1, dim=1)
torch.chunk(input=my_tensor, chunks=1, dim=2)
torch.chunk(input=my_tensor, chunks=1, dim=-1)
torch.chunk(input=my_tensor, chunks=1, dim=-2)
torch.chunk(input=my_tensor, chunks=1, dim=-3)
torch.chunk(input=my_tensor, chunks=2)
torch.chunk(input=my_tensor, chunks=2, dim=0)
torch.chunk(input=my_tensor, chunks=2, dim=-3)
torch.chunk(input=my_tensor, chunks=3)
torch.chunk(input=my_tensor, chunks=3, dim=0)
torch.chunk(input=my_tensor, chunks=3, dim=-3)
torch.chunk(input=my_tensor, chunks=4)
torch.chunk(input=my_tensor, chunks=4, dim=0)
torch.chunk(input=my_tensor, chunks=4, dim=-3)
# (tensor([[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]]),)
# chunks=2 along dim=1: the 3 rows split into chunks of 2 and 1.
torch.chunk(input=my_tensor, chunks=2, dim=1)
torch.chunk(input=my_tensor, chunks=2, dim=-2)
# (tensor([[[0, 1, 2, 3], [4, 5, 6, 7]]]),
# tensor([[[8, 9, 10, 11]]]))
# chunks=2 or chunks=3 along dim=2: the 4 columns split into 2 chunks of 2
# (chunks=3 still yields only 2 chunks).
torch.chunk(input=my_tensor, chunks=2, dim=2)
torch.chunk(input=my_tensor, chunks=2, dim=-1)
torch.chunk(input=my_tensor, chunks=3, dim=2)
torch.chunk(input=my_tensor, chunks=3, dim=-1)
# (tensor([[[0, 1], [4, 5], [8, 9]]]),
# tensor([[[2, 3], [6, 7], [10, 11]]]))
# chunks=3 or chunks=4 along dim=1: one row per chunk, so only 3 chunks
# come back even when 4 are requested.
torch.chunk(input=my_tensor, chunks=3, dim=1)
torch.chunk(input=my_tensor, chunks=3, dim=-2)
torch.chunk(input=my_tensor, chunks=4, dim=1)
torch.chunk(input=my_tensor, chunks=4, dim=-2)
# (tensor([[[0, 1, 2, 3]]]),
# tensor([[[4, 5, 6, 7]]]),
# tensor([[[8, 9, 10, 11]]]))
# chunks=4 along dim=2: one column per chunk; each chunk stays 3D.
torch.chunk(input=my_tensor, chunks=4, dim=2)
torch.chunk(input=my_tensor, chunks=4, dim=-1)
# (tensor([[[0], [4], [8]]]),
# tensor([[[1], [5], [9]]]),
# tensor([[[2], [6], [10]]]),
# tensor([[[3], [7], [11]]]))
# chunk() also accepts float tensors. With chunks=1 the single returned
# chunk equals the input (echoed result shown in the comment below).
my_tensor = torch.tensor([[[0., 1., 2., 3.],
[4., 5., 6., 7.],
[8., 9., 10., 11.]]])
torch.chunk(input=my_tensor, chunks=1)
# (tensor([[[0., 1., 2., 3.],
# [4., 5., 6., 7.],
# [8., 9., 10., 11.]]]),)
# chunk() also accepts complex tensors. With chunks=1 the single returned
# chunk equals the input (echoed result shown in the comment below).
my_tensor = torch.tensor([[[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
[4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
[8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]])
torch.chunk(input=my_tensor, chunks=1)
# (tensor([[[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
# [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
# [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]]),)
# chunk() also accepts bool tensors. With chunks=1 the single returned
# chunk equals the input (echoed result shown in the comment below).
my_tensor = torch.tensor([[[True, False, True, False],
[False, True, False, True],
[True, False, True, False]]])
torch.chunk(input=my_tensor, chunks=1)
# (tensor([[[True, False, True, False],
# [False, True, False, True],
# [True, False, True, False]]]),)
Top comments (0)