DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

split() in PyTorch

*Memos:

split() can split a tensor with one or more dimensions (1D or more) into one or more tensors, as shown below.

*Memos:

  • split() can be used with torch and a tensor.
  • The 1st argument(tensor of int, float, complex or bool) with torch or using a tensor(tensor of int, float, complex or bool) is tensor(Required).
  • The 2nd argument(int, tuple of int or list of int) with torch or the 1st argument(int, tuple of int or list of int) with a tensor is split_size_or_sections(Required). *Don't use split_size_or_sections= with a tensor.
  • The 3rd argument(int) with torch or the 2nd argument(int) with a tensor is dim(Optional-Default:0), the dimension along which tensor is split.
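  • Each returned tensor has zero or more elements, usually fewer than tensor, but together the returned tensors contain exactly the elements of tensor. *With an int, the last tensor is smaller if the size of dim isn't divisible by it. With a tuple or list, the sizes must add up to the size of dim.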
  • The one or more returned tensors keep the number of dimensions of tensor.
import torch

# Each group of calls below is equivalent. The comment under a group shows
# the tuple of tensors that every call in that group returns.
# For this 2D tensor, dim=-2 is the same axis as dim=0 and dim=-1 is the
# same axis as dim=1.
my_tensor = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

torch.split(tensor=my_tensor, split_size_or_sections=1)
my_tensor.split(1)
torch.split(tensor=my_tensor, split_size_or_sections=1, dim=0)
torch.split(tensor=my_tensor, split_size_or_sections=1, dim=-2)
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.split(tensor=my_tensor, split_size_or_sections=1, dim=1)
torch.split(tensor=my_tensor, split_size_or_sections=1, dim=-1)
# (tensor([[0], [4], [8]]),
#  tensor([[1], [5], [9]]),
#  tensor([[2], [6], [10]]),
#  tensor([[3], [7], [11]]))

torch.split(tensor=my_tensor, split_size_or_sections=2)
torch.split(tensor=my_tensor, split_size_or_sections=2, dim=0)
torch.split(tensor=my_tensor, split_size_or_sections=2, dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.split(tensor=my_tensor, split_size_or_sections=2, dim=1)
torch.split(tensor=my_tensor, split_size_or_sections=2, dim=-1)
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.split(tensor=my_tensor, split_size_or_sections=3)
torch.split(tensor=my_tensor, split_size_or_sections=3, dim=0)
torch.split(tensor=my_tensor, split_size_or_sections=3, dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)

# 4 columns aren't divisible by 3, so the last chunk is smaller.
torch.split(tensor=my_tensor, split_size_or_sections=3, dim=1)
torch.split(tensor=my_tensor, split_size_or_sections=3, dim=-1)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([[3], [7], [11]]))

# A section size of 0 yields an empty tensor along dim.
torch.split(tensor=my_tensor, split_size_or_sections=(0, 3))
torch.split(tensor=my_tensor, split_size_or_sections=(0, 3), dim=0)
torch.split(tensor=my_tensor, split_size_or_sections=(0, 3), dim=-2)
# (tensor([], size=(0, 4), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.split(tensor=my_tensor, split_size_or_sections=(1, 2))
torch.split(tensor=my_tensor, split_size_or_sections=(1, 2), dim=0)
torch.split(tensor=my_tensor, split_size_or_sections=(1, 2), dim=-2)
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

torch.split(tensor=my_tensor, split_size_or_sections=(2, 1))
torch.split(tensor=my_tensor, split_size_or_sections=(2, 1), dim=0)
torch.split(tensor=my_tensor, split_size_or_sections=(2, 1), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

torch.split(tensor=my_tensor, split_size_or_sections=(3, 0))
torch.split(tensor=my_tensor, split_size_or_sections=(3, 0), dim=0)
torch.split(tensor=my_tensor, split_size_or_sections=(3, 0), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(0, 4), dtype=torch.int64))

torch.split(tensor=my_tensor, split_size_or_sections=(1, 1, 1))
torch.split(tensor=my_tensor, split_size_or_sections=(1, 1, 1), dim=0)
torch.split(tensor=my_tensor, split_size_or_sections=(1, 1, 1), dim=-2)
# (tensor([[0, 1, 2, 3]]),
#  tensor([[4, 5, 6, 7]]),
#  tensor([[8, 9, 10, 11]]))

my_tensor = torch.tensor([[0., 1., 2., 3.],
                          [4., 5., 6., 7.],
                          [8., 9., 10., 11.]])
torch.split(tensor=my_tensor, split_size_or_sections=1)
# (tensor([[0., 1., 2., 3.]]),
#  tensor([[4., 5., 6., 7.]]),
#  tensor([[8., 9., 10., 11.]]))

my_tensor = torch.tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
                          [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
                          [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]])
torch.split(tensor=my_tensor, split_size_or_sections=1)
# (tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j]]),
#  tensor([[4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j]]),
#  tensor([[8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]))

my_tensor = torch.tensor([[True, False, True, False],
                          [False, True, False, True],
                          [True, False, True, False]])
torch.split(tensor=my_tensor, split_size_or_sections=1)
# (tensor([[True, False, True, False]]),
#  tensor([[False, True, False, True]]),
#  tensor([[True, False, True, False]]))
Enter fullscreen mode Exit fullscreen mode

Top comments (0)