split() can split a 1D or more D tensor into one or more tensors as shown below:
*Memos:
- `split()` can be used with `torch` and a tensor.
- A tensor of zero or more integers, floating-point numbers, complex numbers or boolean values can be used.
- The 2nd argument (`int`, `tuple` of `int` or `list` of `int`) with `torch`, or the 1st argument (`int`, `tuple` of `int` or `list` of `int`) with a tensor, is `split_size_or_sections` (Required).
  *Don't use `split_size_or_sections=` with a tensor.
- The 3rd argument (`int`) with `torch`, or the 2nd argument (`int`) with a tensor, is `dim` (Optional-Default: `0`), which is the dimension to split along.
- The number of elements (zero or more) in each returned tensor can differ from that of the input tensor.
- The returned tensor(s) keep the number of dimensions of the input tensor.
# Examples for torch.split() / Tensor.split() on a 3x4 integer tensor.
# The calls in each group are equivalent. The comment after a group shows
# the tuple of tensors that every call in the group returns.
import torch
my_tensor = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
# Chunks of size 1 along dim 0 (the default). For this 2D tensor dim=-2
# is the same axis as dim=0.
torch.split(my_tensor, split_size_or_sections=1)
my_tensor.split(1)
torch.split(my_tensor, split_size_or_sections=1, dim=0)
my_tensor.split(1, dim=0)
torch.split(my_tensor, split_size_or_sections=1, dim=-2)
my_tensor.split(1, dim=-2)
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))
# Chunks of size 1 along dim 1 (dim=-1 is the same axis here).
torch.split(my_tensor, split_size_or_sections=1, dim=1)
torch.split(my_tensor, split_size_or_sections=1, dim=-1)
# (tensor([[0], [4], [8]]),
# tensor([[1], [5], [9]]),
# tensor([[2], [6], [10]]),
# tensor([[3], [7], [11]]))
# Size 2 along dim 0: there are 3 rows, so the last chunk holds the
# single leftover row.
torch.split(my_tensor, split_size_or_sections=2)
torch.split(my_tensor, split_size_or_sections=2, dim=0)
torch.split(my_tensor, split_size_or_sections=2, dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))
# Size 2 along dim 1: the 4 columns divide evenly into two chunks.
torch.split(my_tensor, split_size_or_sections=2, dim=1)
torch.split(my_tensor, split_size_or_sections=2, dim=-1)
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2, 3], [6, 7], [10, 11]]))
# Size 3 along dim 0 covers all 3 rows, so a single chunk is returned.
torch.split(my_tensor, split_size_or_sections=3)
torch.split(my_tensor, split_size_or_sections=3, dim=0)
torch.split(my_tensor, split_size_or_sections=3, dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)
# Size 3 along dim 1: 3 columns, then the 1 leftover column.
torch.split(my_tensor, split_size_or_sections=3, dim=1)
torch.split(my_tensor, split_size_or_sections=3, dim=-1)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([[3], [7], [11]]))
# A tuple gives the exact size of each chunk along dim 0. In every example
# below the sizes add up to 3 (the length of dim 0); a size of 0 yields an
# empty chunk.
torch.split(my_tensor, split_size_or_sections=(0, 3))
torch.split(my_tensor, split_size_or_sections=(0, 3), dim=0)
torch.split(my_tensor, split_size_or_sections=(0, 3), dim=-2)
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.split(my_tensor, split_size_or_sections=(1, 2))
torch.split(my_tensor, split_size_or_sections=(1, 2), dim=0)
torch.split(my_tensor, split_size_or_sections=(1, 2), dim=-2)
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
torch.split(my_tensor, split_size_or_sections=(2, 1))
torch.split(my_tensor, split_size_or_sections=(2, 1), dim=0)
torch.split(my_tensor, split_size_or_sections=(2, 1), dim=-2)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))
torch.split(my_tensor, split_size_or_sections=(3, 0))
torch.split(my_tensor, split_size_or_sections=(3, 0), dim=0)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))
torch.split(my_tensor, split_size_or_sections=(1, 1, 1))
torch.split(my_tensor, split_size_or_sections=(1, 1, 1), dim=0)
torch.split(my_tensor, split_size_or_sections=(1, 1, 1), dim=-2)
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))
# A tensor built from mixed bool, float and complex literals. The output
# below shows every element as a complex number (False -> 0.+0.j,
# True -> 1.+0.j), and split() works on it the same way.
my_tensor = torch.tensor([[False, True, 2., 3.],
[4., 5., 6., 7+0j],
[8+0j, 9+0j, 10+0j, 11+0j]])
torch.split(my_tensor, split_size_or_sections=1)
torch.split(my_tensor, split_size_or_sections=1, dim=0)
torch.split(my_tensor, split_size_or_sections=1, dim=-2)
# (tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j]]),
# tensor([[4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j]]),
# tensor([[8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]))
hsplit() can horizontally split a 1D or more D tensor into 1 or more tensors as shown below:
- `hsplit()` can be used with `torch` and a tensor.
- A tensor of zero or more integers, floating-point numbers, complex numbers or boolean values can be used.
- The 2nd argument (`int`, `tuple` of `int` or `list` of `int`) with `torch`, or the 1st argument (`int`, `tuple` of `int` or `list` of `int`) with a tensor, is `indices_or_sections` (Required).
  *Don't use `indices_or_sections=` with `torch` or a tensor.
- The number of elements (zero or more) in each returned tensor can differ from that of the input tensor.
- The returned tensor(s) keep the number of dimensions of the input tensor.
# Examples for torch.hsplit() / Tensor.hsplit() on a 3x4 integer tensor.
# hsplit() cuts this 2D tensor along dim 1 (the columns). The comment
# after each group of equivalent calls shows the tuple they return.
import torch
my_tensor = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
# An int N splits the 4 columns into N equal pieces.
torch.hsplit(my_tensor, 1)
my_tensor.hsplit(1)
# (tensor([[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10, 11]]),)
torch.hsplit(my_tensor, 2)
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2, 3], [6, 7], [10, 11]]))
torch.hsplit(my_tensor, 4)
# (tensor([[0], [4], [8]]),
# tensor([[1], [5], [9]]),
# tensor([[2], [6], [10]]),
# tensor([[3], [7], [11]]))
# A tuple of ints lists the cut positions along dim 1. Paired lines are
# equivalent: a negative index counts from the end of dim 1 (size 4), so
# -4 == 0, -3 == 1, -2 == 2 and -1 == 3.
torch.hsplit(my_tensor, (0,))
torch.hsplit(my_tensor, (-4,))
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.hsplit(my_tensor, (1,))
torch.hsplit(my_tensor, (-3,))
# (tensor([[0], [4], [8]]),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))
torch.hsplit(my_tensor, (2,))
torch.hsplit(my_tensor, (-2,))
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2, 3], [6, 7], [10, 11]]))
torch.hsplit(my_tensor, (3,))
torch.hsplit(my_tensor, (-1,))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([[3], [7], [11]]))
torch.hsplit(my_tensor, (4,))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64))
# Two indices (i, j) give three pieces: columns [:i], [i:j] and [j:].
# When j <= i, the middle piece is empty and the last piece still starts
# at column j, so columns can appear in more than one piece.
torch.hsplit(my_tensor, (0, 0))
torch.hsplit(my_tensor, (0, -4))
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.hsplit(my_tensor, (0, 1))
torch.hsplit(my_tensor, (0, -3))
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0], [4], [8]]),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))
torch.hsplit(my_tensor, (0, 2))
torch.hsplit(my_tensor, (0, -2))
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2, 3], [6, 7], [10, 11]]))
torch.hsplit(my_tensor, (0, 3))
torch.hsplit(my_tensor, (0, -1))
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([[3], [7], [11]]))
torch.hsplit(my_tensor, (0, 4))
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64))
torch.hsplit(my_tensor, (1, 0))
torch.hsplit(my_tensor, (1, -4))
# (tensor([[0], [4], [8]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.hsplit(my_tensor, (1, 1))
torch.hsplit(my_tensor, (1, -3))
# (tensor([[0], [4], [8]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))
torch.hsplit(my_tensor, (1, 2))
torch.hsplit(my_tensor, (1, -2))
# (tensor([[0], [4], [8]]),
# tensor([[1], [5], [9]]),
# tensor([[2, 3], [6, 7], [10, 11]]))
torch.hsplit(my_tensor, (1, 3))
torch.hsplit(my_tensor, (1, -1))
# (tensor([[0], [4], [8]]),
# tensor([[1, 2], [5, 6], [9, 10]]),
# tensor([[3], [7], [11]]))
torch.hsplit(my_tensor, (1, 4))
# (tensor([[0], [4], [8]]),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64))
torch.hsplit(my_tensor, (2, 0))
torch.hsplit(my_tensor, (2, -4))
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.hsplit(my_tensor, (2, 1))
torch.hsplit(my_tensor, (2, -3))
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))
torch.hsplit(my_tensor, (2, 2))
torch.hsplit(my_tensor, (2, -2))
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[2, 3], [6, 7], [10, 11]]))
torch.hsplit(my_tensor, (2, 3))
torch.hsplit(my_tensor, (2, -1))
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2], [6], [10]]),
# tensor([[3], [7], [11]]))
torch.hsplit(my_tensor, (2, 4))
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2, 3], [6, 7], [10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64))
torch.hsplit(my_tensor, (3, 0))
torch.hsplit(my_tensor, (3, -4))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.hsplit(my_tensor, (3, 1))
torch.hsplit(my_tensor, (3, -3))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))
torch.hsplit(my_tensor, (3, 2))
torch.hsplit(my_tensor, (3, -2))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[2, 3], [6, 7], [10, 11]]))
torch.hsplit(my_tensor, (3, 3))
torch.hsplit(my_tensor, (3, -1))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[3], [7], [11]]))
torch.hsplit(my_tensor, (3, 4))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([[3], [7], [11]]),
# tensor([], size=(3, 0), dtype=torch.int64))
torch.hsplit(my_tensor, (4, 0))
torch.hsplit(my_tensor, (4, -4))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.hsplit(my_tensor, (4, 1))
torch.hsplit(my_tensor, (4, -3))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))
torch.hsplit(my_tensor, (4, 2))
torch.hsplit(my_tensor, (4, -2))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[2, 3], [6, 7], [10, 11]]))
torch.hsplit(my_tensor, (4, 3))
torch.hsplit(my_tensor, (4, -1))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[3], [7], [11]]))
torch.hsplit(my_tensor, (4, 4))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([], size=(3, 0), dtype=torch.int64))
# First index negative: -1 is column 3.
torch.hsplit(my_tensor, (-1, 0))
torch.hsplit(my_tensor, (-1, -4))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.hsplit(my_tensor, (-1, 1))
torch.hsplit(my_tensor, (-1, -3))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))
torch.hsplit(my_tensor, (-1, 2))
torch.hsplit(my_tensor, (-1, -2))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[2, 3], [6, 7], [10, 11]]))
torch.hsplit(my_tensor, (-1, 3))
torch.hsplit(my_tensor, (-1, -1))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[3], [7], [11]]))
torch.hsplit(my_tensor, (-1, 4))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([[3], [7], [11]]),
# tensor([], size=(3, 0), dtype=torch.int64))
# The (-1, -1) and (-1, -2) index pairs again, with outputs for the 3x4
# my_tensor. hsplit() cuts along dim 1 (size 4), so -1 == 3 and -2 == 2,
# and every piece stays 2D with 3 rows.
torch.hsplit(my_tensor, (-1, -1))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[3], [7], [11]]))
torch.hsplit(my_tensor, (-1, -2))
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[2, 3], [6, 7], [10, 11]]))
# More two-index hsplit() cases with a negative first index. Along dim 1
# (size 4), -2 == 2, -3 == 1 and -4 == 0. The pieces are columns [:i],
# [i:j] and [j:].
torch.hsplit(my_tensor, (-2, 0))
torch.hsplit(my_tensor, (-2, -4))
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.hsplit(my_tensor, (-2, 1))
torch.hsplit(my_tensor, (-2, -3))
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))
torch.hsplit(my_tensor, (-2, 2))
torch.hsplit(my_tensor, (-2, -2))
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[2, 3], [6, 7], [10, 11]]))
torch.hsplit(my_tensor, (-2, 3))
torch.hsplit(my_tensor, (-2, -1))
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2], [6], [10]]),
# tensor([[3], [7], [11]]))
torch.hsplit(my_tensor, (-2, 4))
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2, 3], [6, 7], [10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64))
torch.hsplit(my_tensor, (-3, 0))
torch.hsplit(my_tensor, (-3, -4))
# (tensor([[0], [4], [8]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.hsplit(my_tensor, (-3, 1))
torch.hsplit(my_tensor, (-3, -3))
# (tensor([[0], [4], [8]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))
torch.hsplit(my_tensor, (-3, 2))
torch.hsplit(my_tensor, (-3, -2))
# (tensor([[0], [4], [8]]),
# tensor([[1], [5], [9]]),
# tensor([[2, 3], [6, 7], [10, 11]]))
torch.hsplit(my_tensor, (-3, 3))
torch.hsplit(my_tensor, (-3, -1))
# (tensor([[0], [4], [8]]),
# tensor([[1, 2], [5, 6], [9, 10]]),
# tensor([[3], [7], [11]]))
torch.hsplit(my_tensor, (-3, 4))
# (tensor([[0], [4], [8]]),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64))
torch.hsplit(my_tensor, (-4, 0))
torch.hsplit(my_tensor, (-4, -4))
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.hsplit(my_tensor, (-4, 1))
torch.hsplit(my_tensor, (-4, -3))
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0], [4], [8]]),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))
torch.hsplit(my_tensor, (-4, 2))
torch.hsplit(my_tensor, (-4, -2))
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2, 3], [6, 7], [10, 11]]))
torch.hsplit(my_tensor, (-4, 3))
torch.hsplit(my_tensor, (-4, -1))
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([[3], [7], [11]]))
torch.hsplit(my_tensor, (-4, 4))
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64))
# Three indices give four pieces. With every index equal to column 0,
# the first three pieces are empty and the last is the whole tensor.
torch.hsplit(my_tensor, (0, 0, 0))
torch.hsplit(my_tensor, (0, 0, -4))
torch.hsplit(my_tensor, (0, -4, 0))
torch.hsplit(my_tensor, (0, -4, -4))
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
etc.
# hsplit() on a tensor built from mixed bool, float and complex literals.
# The output shows every element as a complex number.
my_tensor = torch.tensor([[False, True, 2., 3.],
[4., 5., 6., 7+0j],
[8+0j, 9+0j, 10+0j, 11+0j]])
torch.hsplit(my_tensor, 1)
# (tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
# [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
# [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]),)
vsplit() can vertically split a 2D or more D tensor into one or more tensors as shown below:
*Memos:
- `vsplit()` can be used with `torch` and a tensor.
- A tensor of zero or more integers, floating-point numbers, complex numbers or boolean values can be used.
- The 2nd argument (`int`, `tuple` of `int` or `list` of `int`) with `torch`, or the 1st argument (`int`, `tuple` of `int` or `list` of `int`) with a tensor, is `indices_or_sections` (Required).
  *Don't use `indices_or_sections=` with `torch` or a tensor.
- The number of elements (zero or more) in each returned tensor can differ from that of the input tensor.
- The returned tensor(s) keep the number of dimensions of the input tensor.
# Examples for torch.vsplit() / Tensor.vsplit() on a 3x4 integer tensor.
# vsplit() cuts along dim 0 (the rows). The comment after each group of
# equivalent calls shows the tuple they return.
import torch
my_tensor = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
# An int N splits the 3 rows into N equal pieces.
torch.vsplit(my_tensor, 1)
my_tensor.vsplit(1)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)
torch.vsplit(my_tensor, 3)
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))
# A tuple of ints lists the cut positions along dim 0. Paired lines are
# equivalent: a negative index counts from the end of dim 0 (size 3), so
# -3 == 0, -2 == 1 and -1 == 2.
torch.vsplit(my_tensor, (0,))
torch.vsplit(my_tensor, (-3,))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (1,))
torch.vsplit(my_tensor, (-2,))
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (2,))
torch.vsplit(my_tensor, (-1,))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))
torch.vsplit(my_tensor, (3,))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))
# Two indices (i, j) give three pieces: rows [:i], [i:j] and [j:].
# When j <= i, the middle piece is empty and the last piece still starts
# at row j, so rows can appear in more than one piece.
torch.vsplit(my_tensor, (0, 0))
torch.vsplit(my_tensor, (0, -3))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (0, 1))
torch.vsplit(my_tensor, (0, -2))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (0, 2))
torch.vsplit(my_tensor, (0, -1))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))
torch.vsplit(my_tensor, (0, 3))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))
torch.vsplit(my_tensor, (1, 0))
torch.vsplit(my_tensor, (1, -3))
# (tensor([[0, 1, 2, 3]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (1, 1))
torch.vsplit(my_tensor, (1, -2))
# (tensor([[0, 1, 2, 3]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (1, 2))
torch.vsplit(my_tensor, (1, -1))
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))
torch.vsplit(my_tensor, (1, 3))
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))
torch.vsplit(my_tensor, (2, 0))
torch.vsplit(my_tensor, (2, -3))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (2, 1))
torch.vsplit(my_tensor, (2, -2))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (2, 2))
torch.vsplit(my_tensor, (2, -1))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[8, 9, 10, 11]]))
torch.vsplit(my_tensor, (2, 3))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))
torch.vsplit(my_tensor, (3, 0))
torch.vsplit(my_tensor, (3, -3))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (3, 1))
torch.vsplit(my_tensor, (3, -2))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (3, 2))
torch.vsplit(my_tensor, (3, -1))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[8, 9, 10, 11]]))
torch.vsplit(my_tensor, (3, 3))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([], size=(0, 4), dtype=torch.int64))
# First index negative: -1 is row 2, -2 is row 1, -3 is row 0.
torch.vsplit(my_tensor, (-1, 0))
torch.vsplit(my_tensor, (-1, -3))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (-1, 1))
torch.vsplit(my_tensor, (-1, -2))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (-1, 2))
torch.vsplit(my_tensor, (-1, -1))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[8, 9, 10, 11]]))
torch.vsplit(my_tensor, (-1, 3))
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))
torch.vsplit(my_tensor, (-2, 0))
# (tensor([[0, 1, 2, 3]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (-2, 1))
# (tensor([[0, 1, 2, 3]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (-2, 2))
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))
torch.vsplit(my_tensor, (-2, 3))
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))
torch.vsplit(my_tensor, (-3, 0))
torch.vsplit(my_tensor, (-3, -3))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (-3, 1))
torch.vsplit(my_tensor, (-3, -2))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))
torch.vsplit(my_tensor, (-3, 2))
torch.vsplit(my_tensor, (-3, -1))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))
torch.vsplit(my_tensor, (-3, 3))
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))
# vsplit() on a tensor built from mixed bool, float and complex literals.
# The output shows every element as a complex number.
my_tensor = torch.tensor([[False, True, 2., 3.],
[4., 5., 6., 7+0j],
[8+0j, 9+0j, 10+0j, 11+0j]])
torch.vsplit(my_tensor, 1)
# (tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
# [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
# [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]),)
Top comments (0)