DEV Community

Super Kai (Kazuya Ito)

Posted on

hsplit() in PyTorch

*Memos:

hsplit() can horizontally split a 1D or higher-dimensional tensor into one or more tensors (along dim 1, or along dim 0 for a 1D tensor), as shown below:

  • hsplit() can be called both as torch.hsplit() and as a tensor method (e.g. my_tensor.hsplit()).
  • With torch, the 1st argument (a tensor of int, float, complex or bool) is input (required); with the tensor method, the tensor itself is the input.
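  • The 2nd argument (int) with torch, or the 1st argument (int) with a tensor, is sections. It must evenly divide the size of the split dimension, otherwise an error is raised.
  • Alternatively, the 2nd argument (tuple of int or list of int) with torch, or the 1st argument (tuple of int or list of int) with a tensor, is indices. Exactly one of sections or indices must be given.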
  • The returned tensors may be empty, and their total number of elements can exceed that of the input: with non-increasing indices such as (4, 0), the returned tensors overlap.
  • Each returned tensor has the same number of dimensions as the input tensor.
import torch

# A 3x4 int64 tensor (rows [0..3], [4..7], [8..11]). hsplit divides it along
# its columns, so every piece shown below keeps all 3 rows.
my_tensor = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
# The function form and the method form return the same tuple.
torch.hsplit(input=my_tensor, sections=1)
my_tensor.hsplit(sections=1)
# (tensor([[0, 1, 2, 3],
#          [4, 5, 6, 7],
#          [8, 9, 10, 11]]),)

# sections=n gives n equal-width pieces. NOTE: per the torch.hsplit docs, n
# must evenly divide the column count (4 here); 1, 2 and 4 do, 3 would raise.
torch.hsplit(input=my_tensor, sections=2)
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.hsplit(input=my_tensor, sections=4)
# (tensor([[0], [4], [8]]),
#  tensor([[1], [5], [9]]),
#  tensor([[2], [6], [10]]),
#  tensor([[3], [7], [11]]))

# indices=(i,) cuts at column i, giving my_tensor[:, :i] and my_tensor[:, i:].
# A negative i counts from the end (-4..-1 act like 0..3 on these 4 columns),
# so each pair of calls below returns the same result. A cut at column 0 or 4
# yields an empty (3, 0) piece.
torch.hsplit(input=my_tensor, indices=(0,))
torch.hsplit(input=my_tensor, indices=(-4,))
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(1,))
torch.hsplit(input=my_tensor, indices=(-3,))
# (tensor([[0], [4], [8]]),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(2,))
torch.hsplit(input=my_tensor, indices=(-2,))
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.hsplit(input=my_tensor, indices=(3,))
torch.hsplit(input=my_tensor, indices=(-1,))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([[3], [7], [11]]))

torch.hsplit(input=my_tensor, indices=(4,))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

# indices=(a, b) gives three pieces: my_tensor[:, :a], my_tensor[:, a:b] and
# my_tensor[:, b:]. Here a == 0, so the first piece is always empty; negative
# b values count from the end and match their non-negative counterparts.
torch.hsplit(input=my_tensor, indices=(0, 0))
torch.hsplit(input=my_tensor, indices=(0, -4))
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(0, 1))
torch.hsplit(input=my_tensor, indices=(0, -3))
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0], [4], [8]]),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(0, 2))
torch.hsplit(input=my_tensor, indices=(0, -2))
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.hsplit(input=my_tensor, indices=(0, 3))
torch.hsplit(input=my_tensor, indices=(0, -1))
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([[3], [7], [11]]))

torch.hsplit(input=my_tensor, indices=(0, 4))
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

# a == 1: the first piece is column 0. When b <= a the middle piece is empty;
# when b < a the last piece (my_tensor[:, b:]) also repeats column 0.
torch.hsplit(input=my_tensor, indices=(1, 0))
torch.hsplit(input=my_tensor, indices=(1, -4))
# (tensor([[0], [4], [8]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(1, 1))
torch.hsplit(input=my_tensor, indices=(1, -3))
# (tensor([[0], [4], [8]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(1, 2))
torch.hsplit(input=my_tensor, indices=(1, -2))
# (tensor([[0], [4], [8]]),
#  tensor([[1], [5], [9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.hsplit(input=my_tensor, indices=(1, 3))
torch.hsplit(input=my_tensor, indices=(1, -1))
# (tensor([[0], [4], [8]]),
#  tensor([[1, 2], [5, 6], [9, 10]]),
#  tensor([[3], [7], [11]]))

torch.hsplit(input=my_tensor, indices=(1, 4))
# (tensor([[0], [4], [8]]),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

# a == 2: the first piece is columns 0-1. For b <= 2 the middle piece is empty
# and the last piece is my_tensor[:, b:], which may repeat columns 0-1.
torch.hsplit(input=my_tensor, indices=(2, 0))
torch.hsplit(input=my_tensor, indices=(2, -4))
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(2, 1))
torch.hsplit(input=my_tensor, indices=(2, -3))
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(2, 2))
torch.hsplit(input=my_tensor, indices=(2, -2))
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.hsplit(input=my_tensor, indices=(2, 3))
torch.hsplit(input=my_tensor, indices=(2, -1))
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2], [6], [10]]),
#  tensor([[3], [7], [11]]))

torch.hsplit(input=my_tensor, indices=(2, 4))
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

# a == 3: the first piece is columns 0-2. Only b == 4 gives a non-empty middle
# piece (column 3); otherwise the last piece is my_tensor[:, b:].
torch.hsplit(input=my_tensor, indices=(3, 0))
torch.hsplit(input=my_tensor, indices=(3, -4))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(3, 1))
torch.hsplit(input=my_tensor, indices=(3, -3))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(3, 2))
torch.hsplit(input=my_tensor, indices=(3, -2))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.hsplit(input=my_tensor, indices=(3, 3))
torch.hsplit(input=my_tensor, indices=(3, -1))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[3], [7], [11]]))

torch.hsplit(input=my_tensor, indices=(3, 4))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([[3], [7], [11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

# a == 4 (the column count): the first piece is the whole tensor, the middle
# piece is always empty, and the last piece is my_tensor[:, b:].
torch.hsplit(input=my_tensor, indices=(4, 0))
torch.hsplit(input=my_tensor, indices=(4, -4))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(4, 1))
torch.hsplit(input=my_tensor, indices=(4, -3))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(4, 2))
torch.hsplit(input=my_tensor, indices=(4, -2))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.hsplit(input=my_tensor, indices=(4, 3))
torch.hsplit(input=my_tensor, indices=(4, -1))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[3], [7], [11]]))

torch.hsplit(input=my_tensor, indices=(4, 4))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([], size=(3, 0), dtype=torch.int64))

# a == -1 counts from the end and resolves to column 3 on this 4-column
# tensor, so these results match the corresponding (3, b) calls.
torch.hsplit(input=my_tensor, indices=(-1, 0))
torch.hsplit(input=my_tensor, indices=(-1, -4))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(-1, 1))
torch.hsplit(input=my_tensor, indices=(-1, -3))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(-1, 2))
torch.hsplit(input=my_tensor, indices=(-1, -2))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.hsplit(input=my_tensor, indices=(-1, 3))
torch.hsplit(input=my_tensor, indices=(-1, -1))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[3], [7], [11]]))

torch.hsplit(input=my_tensor, indices=(-1, 4))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([[3], [7], [11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

# These two calls repeat earlier examples. On this 4-column tensor -1 resolves
# to column 3 and -2 to column 2, so each call returns my_tensor[:, :a],
# my_tensor[:, a:b] and my_tensor[:, b:] as three (3, k) int64 tensors.
torch.hsplit(input=my_tensor, indices=(-1, -1))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[3], [7], [11]]))

torch.hsplit(input=my_tensor, indices=(-1, -2))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[2, 3], [6, 7], [10, 11]]))

# a == -2 resolves to column 2 on this 4-column tensor, so these results match
# the corresponding (2, b) calls.
torch.hsplit(input=my_tensor, indices=(-2, 0))
torch.hsplit(input=my_tensor, indices=(-2, -4))
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(-2, 1))
torch.hsplit(input=my_tensor, indices=(-2, -3))
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(-2, 2))
torch.hsplit(input=my_tensor, indices=(-2, -2))
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.hsplit(input=my_tensor, indices=(-2, 3))
torch.hsplit(input=my_tensor, indices=(-2, -1))
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2], [6], [10]]),
#  tensor([[3], [7], [11]]))

torch.hsplit(input=my_tensor, indices=(-2, 4))
# (tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

# a == -3 resolves to column 1 on this 4-column tensor, so these results match
# the corresponding (1, b) calls.
torch.hsplit(input=my_tensor, indices=(-3, 0))
torch.hsplit(input=my_tensor, indices=(-3, -4))
# (tensor([[0], [4], [8]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(-3, 1))
torch.hsplit(input=my_tensor, indices=(-3, -3))
# (tensor([[0], [4], [8]]),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(-3, 2))
torch.hsplit(input=my_tensor, indices=(-3, -2))
# (tensor([[0], [4], [8]]),
#  tensor([[1], [5], [9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.hsplit(input=my_tensor, indices=(-3, 3))
torch.hsplit(input=my_tensor, indices=(-3, -1))
# (tensor([[0], [4], [8]]),
#  tensor([[1, 2], [5, 6], [9, 10]]),
#  tensor([[3], [7], [11]]))

torch.hsplit(input=my_tensor, indices=(-3, 4))
# (tensor([[0], [4], [8]]),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

# a == -4 resolves to column 0 on this 4-column tensor, so the first piece is
# empty and these results match the corresponding (0, b) calls.
torch.hsplit(input=my_tensor, indices=(-4, 0))
torch.hsplit(input=my_tensor, indices=(-4, -4))
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(-4, 1))
torch.hsplit(input=my_tensor, indices=(-4, -3))
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0], [4], [8]]),
#  tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

torch.hsplit(input=my_tensor, indices=(-4, 2))
torch.hsplit(input=my_tensor, indices=(-4, -2))
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1], [4, 5], [8, 9]]),
#  tensor([[2, 3], [6, 7], [10, 11]]))

torch.hsplit(input=my_tensor, indices=(-4, 3))
torch.hsplit(input=my_tensor, indices=(-4, -1))
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
#  tensor([[3], [7], [11]]))

torch.hsplit(input=my_tensor, indices=(-4, 4))
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
#  tensor([], size=(3, 0), dtype=torch.int64))

# Three cut points give four pieces. Every index here resolves to column 0
# (-4 counts from the end), so the first three pieces are empty and the last
# one is the whole tensor.
torch.hsplit(input=my_tensor, indices=(0, 0, 0))
torch.hsplit(input=my_tensor, indices=(0, 0, -4))
torch.hsplit(input=my_tensor, indices=(0, -4, 0))
torch.hsplit(input=my_tensor, indices=(0, -4, -4))
# (tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([], size=(3, 0), dtype=torch.int64),
#  tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
etc. (the remaining index combinations follow the same slicing rules)

# Float input: the same 3x4 layout of 0..11, built as a float32 range and
# reshaped; the method form of hsplit is used here.
my_tensor = torch.arange(12.).reshape(3, 4)
my_tensor.hsplit(sections=1)
# (tensor([[0., 1., 2., 3.],
#          [4., 5., 6., 7.],
#          [8., 9., 10., 11.]]),)

# Complex input: 0..11 with zero imaginary parts, laid out as 3x4 and cast to
# complex64; the method form of hsplit is used here.
my_tensor = torch.arange(12.).reshape(3, 4).to(torch.complex64)
my_tensor.hsplit(sections=1)
# (tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
#          [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
#          [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]),)

# Bool input: alternating True/False rows, with the middle row inverted;
# the method form of hsplit is used here.
my_tensor = torch.tensor([[True, False] * 2, [False, True] * 2, [True, False] * 2])
my_tensor.hsplit(sections=1)
# (tensor([[True, False, True, False],
#          [False, True, False, True],
#          [True, False, True, False]]),)
Enter fullscreen mode Exit fullscreen mode

Top comments (0)