*Memos:
- My post explains dsplit().
- My post explains split().
- My post explains vsplit().
- My post explains hsplit().
- My post explains chunk().
- My post explains unbind().
tensor_split() can split a 1D or more D tensor of zero or more elements into one or more 1D or more D tensors of zero or more elements, as shown below:
*Memos:
- tensor_split() can be used with torch or a tensor.
- The 1st argument (input) with torch, or the tensor itself when it is called as a method (Required-Type: tensor of int, float, complex or bool). *It must be a 1D or more D tensor.
- The 2nd argument with torch or the 1st argument with a tensor is sections (Required-Type: int).
- The 2nd argument with torch or the 1st argument with a tensor is indices (Required-Type: tuple of int or list of int).
- The 2nd argument with torch or the 1st argument with a tensor is tensor_indices_or_sections (Required-Type: tensor of int). *It must be a 0D or 1D tensor.
- *Use exactly one of sections, indices and tensor_indices_or_sections in a call.
- The 3rd argument with torch or the 2nd argument with a tensor is dim (Optional-Default: 0, Type: int).
- Each returned tensor can have a different number of elements from input.
- The total number of elements of the returned tensors can differ from that of input, e.g. indices=(1, 0) returns row 0 twice.
- Every returned tensor has the same number of dimensions as input.
import torch
# 3x4 integer tensor shared by the examples below.
my_tensor = torch.tensor([[0, 1, 2, 3],
                          [4, 5, 6, 7],
                          [8, 9, 10, 11]])

# One section -- given as an int or as a 0D tensor -- returns the input
# unsplit inside a 1-tuple on every dim tried here; the torch.* function and
# the Tensor method give the same result.
torch.tensor_split(input=my_tensor, sections=1)
my_tensor.tensor_split(sections=1)
for axis in (0, 1, -1, -2):
    torch.tensor_split(input=my_tensor, sections=1, dim=axis)
for axis in (0, 1, -1, -2):
    torch.tensor_split(input=my_tensor,
                       tensor_indices_or_sections=torch.tensor(1), dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),)

# A single index k yields two pieces, [:k] and [k:]. On dim 0 (the default,
# also reachable as -2 for this 2D tensor) k=1 cuts before row 1 ...
torch.tensor_split(input=my_tensor, indices=(1,))
for axis in (0, -2):
    torch.tensor_split(input=my_tensor, indices=(1,), dim=axis)
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

# ... and on dim 1 (-1) it cuts before column 1.
for axis in (1, -1):
    torch.tensor_split(input=my_tensor, indices=(1,), dim=axis)
# (tensor([[0], [4], [8]]),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))
# dim 0 has 3 rows, so sections=2 yields sizes 2 and 1: the same cut as
# index 2, which -1 also denotes (-1 + 3 == 2).
torch.tensor_split(input=my_tensor, sections=2)
torch.tensor_split(input=my_tensor, indices=(2,))
for axis in (0, -2):
    torch.tensor_split(input=my_tensor, sections=2, dim=axis)
    torch.tensor_split(input=my_tensor, indices=(2,), dim=axis)
torch.tensor_split(input=my_tensor, indices=(-1,))
for axis in (0, -2):
    torch.tensor_split(input=my_tensor, indices=(-1,), dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

# dim 1 has 4 columns, so sections=2 yields two halves, i.e. a cut at 2.
for axis in (1, -1):
    torch.tensor_split(input=my_tensor, sections=2, dim=axis)
    torch.tensor_split(input=my_tensor, indices=(2,), dim=axis)
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2, 3], [6, 7], [10, 11]]))

# Three sections of the 3 rows == cuts at rows 1 and 2, however the indices
# are spelled (negative ones count back from 3).
torch.tensor_split(input=my_tensor, sections=3)
for axis in (0, -2):
    torch.tensor_split(input=my_tensor, sections=3, dim=axis)
for idx in ((1, 2), (1, -1), (-2, 2), (-2, -1)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
for axis in (0, -2):
    torch.tensor_split(input=my_tensor,
                       tensor_indices_or_sections=torch.tensor([1, 2]),
                       dim=axis)
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

# A cut at the very end of dim 0 leaves an empty (0, 4) tail.
for axis in (0, -2):
    torch.tensor_split(input=my_tensor, indices=(3,), dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))

# On dim 1, index 3 and index -1 (-1 + 4 == 3) name the same cut.
for idx in ((3,), (-1,)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([[3], [7], [11]]))

# Three sections of 4 columns: the first section takes the extra column,
# giving sizes 2, 1, 1 == cuts at 2 and 3.
for axis in (1, -1):
    torch.tensor_split(input=my_tensor, sections=3, dim=axis)
for idx in ((2, 3), (2, -1), (-2, -1)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2], [6], [10]]),
# tensor([[3], [7], [11]]))
# An index pair (a, b) gives three pieces: [:a], [a:b], [b:]. Negative
# indices count back from the dim size and are then clamped to 0, so -4 on
# dim 0 (3 rows) behaves like 0.
for idx in ((0, 0), (0, -3), (-3, 0), (-3, -3), (-4, -4)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

for idx in ((0, 0), (0, -4), (-4, -4)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

for idx in ((0, 1), (0, -2), (-3, 1), (-3, -2)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

for idx in ((0, 1), (0, -3)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0], [4], [8]]),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

for idx in ((0, 2), (0, -1), (-3, 2), (-3, -1)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]))

for idx in ((0, 2), (0, -2)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2, 3], [6, 7], [10, 11]]))

for idx in ((0, 3), (-3, 3)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))

for idx in ((0, 3), (0, -1)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([[3], [7], [11]]))

for idx in ((0, 4), (-4, 4)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64))
# When the second index is smaller than the first, the middle piece is empty
# and the last piece restarts at the second index, so some rows/columns show
# up twice across the result.
for idx in ((1, 0), (1, -3), (-2, 0), (-2, -3)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2, 3]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

for idx in ((1, 0), (1, -4), (-3, 0), (-3, -4)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0], [4], [8]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

for idx in ((1, 1), (1, -2), (-2, 1), (-2, -2)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2, 3]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

for idx in ((1, 1), (1, -3), (-3, 1), (-3, -3)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0], [4], [8]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

for idx in ((1, 2), (1, -2), (-3, 2), (-3, -2)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
for axis in (1, -1):
    torch.tensor_split(input=my_tensor,
                       tensor_indices_or_sections=torch.tensor([1, 2]),
                       dim=axis)
# (tensor([[0], [4], [8]]),
# tensor([[1], [5], [9]]),
# tensor([[2, 3], [6, 7], [10, 11]]))

torch.tensor_split(input=my_tensor, indices=(1, 3))
for axis in (0, -2):
    torch.tensor_split(input=my_tensor, indices=(1, 3), dim=axis)
# (tensor([[0, 1, 2, 3]]),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))

for idx in ((1, 3), (1, -1), (-3, 3), (-3, -1)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0], [4], [8]]),
# tensor([[1, 2], [5, 6], [9, 10]]),
# tensor([[3], [7], [11]]))

for idx in ((1, 4), (-3, 4)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0], [4], [8]]),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64))
# First cut at row/column 2 (written 2, -1 on dim 0 or -2 on dim 1),
# combined with each possible second cut.
for idx in ((2, 0), (2, -3), (-1, 0)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

for idx in ((2, 0), (2, -4), (-2, 0), (-2, -4)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

for idx in ((2, 1), (2, -2), (-1, 1), (-1, -2)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

for idx in ((2, 2), (2, -1), (-1, 2), (-1, -1)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[8, 9, 10, 11]]))

for idx in ((2, 2), (2, -2), (-2, 2), (-2, -2)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[2, 3], [6, 7], [10, 11]]))

torch.tensor_split(input=my_tensor, indices=(2, 3))
for axis in (0, -2):
    torch.tensor_split(input=my_tensor, indices=(2, 3), dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
# tensor([[8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64))

for idx in ((2, 4), (-2, 4)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([[2, 3], [6, 7], [10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64))
# First cut at position 3 (or 4 on dim 1), combined with each possible
# second cut. Every call passes the index tuple by keyword (indices=):
# once input= has been given by keyword, a positional argument after it is a
# SyntaxError, which would stop the whole script before any example runs.
for idx in ((3, 0), (3, -3)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

for idx in ((3, 0), (3, -4), (-1, 0), (-1, -4)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))

for idx in ((3, 1), (3, -2)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[4, 5, 6, 7], [8, 9, 10, 11]]))

for idx in ((3, 1), (3, -3), (-1, 1)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

for idx in ((3, 2), (3, -1)):
    torch.tensor_split(input=my_tensor, indices=idx)
    for axis in (0, -2):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[8, 9, 10, 11]]))

for idx in ((3, 2), (3, -2), (-1, 2), (-1, -2)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[2, 3], [6, 7], [10, 11]]))

torch.tensor_split(input=my_tensor, indices=(3, 3))
for axis in (0, -2):
    torch.tensor_split(input=my_tensor, indices=(3, 3), dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([], size=(0, 4), dtype=torch.int64))

for idx in ((3, 3), (3, -1), (-1, -1)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[3], [7], [11]]))

for idx in ((3, 4), (-1, 4)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1, 2], [4, 5, 6], [8, 9, 10]]),
# tensor([[3], [7], [11]]),
# tensor([], size=(3, 0), dtype=torch.int64))

for axis in (1, -1):
    torch.tensor_split(input=my_tensor, indices=(4, 4), dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([], size=(3, 0), dtype=torch.int64))

for axis in (1, -1):
    torch.tensor_split(input=my_tensor, indices=(4, -4), dim=axis)
# (tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
# On dim 1 (4 columns) these pairs all resolve to (2, 1).
for idx in ((2, -3), (-2, 1), (-2, -3)):
    for axis in (1, -1):
        torch.tensor_split(input=my_tensor, indices=idx, dim=axis)
# (tensor([[0, 1], [4, 5], [8, 9]]),
# tensor([], size=(3, 0), dtype=torch.int64),
# tensor([[1, 2, 3], [5, 6, 7], [9, 10, 11]]))

# n indices always give n + 1 pieces, even when every cut is at 0.
torch.tensor_split(input=my_tensor, indices=(0, 0, 0))
for axis in (0, -2):
    torch.tensor_split(input=my_tensor, indices=(0, 0, 0), dim=axis)
# (tensor([], size=(0, 4), dtype=torch.int64),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([], size=(0, 4), dtype=torch.int64),
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]))
etc.
# The same split applied to a float tensor.
my_tensor = torch.tensor([[0., 1., 2., 3.],
                          [4., 5., 6., 7.],
                          [8., 9., 10., 11.]])
torch.tensor_split(input=my_tensor, sections=1)
# (tensor([[0., 1., 2., 3.],
#          [4., 5., 6., 7.],
#          [8., 9., 10., 11.]]),)
# The same split applied to a complex tensor.
my_tensor = torch.tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
                          [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
                          [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]])
torch.tensor_split(input=my_tensor, sections=1)
# (tensor([[0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j],
#          [4.+0.j, 5.+0.j, 6.+0.j, 7.+0.j],
#          [8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j]]),)
# The same split applied to a bool tensor.
my_tensor = torch.tensor([[True, False, True, False],
                          [False, True, False, True],
                          [True, False, True, False]])
torch.tensor_split(input=my_tensor, sections=1)
# (tensor([[True, False, True, False],
#          [False, True, False, True],
#          [True, False, True, False]]),)
Top comments (0)