DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

cat in PyTorch

Buy Me a Coffee

*Memos:

cat() concatenates one or more 1D or more D tensors of zero or more elements along a given dimension, returning a 1D or more D tensor of zero or more elements, as shown below:

*Memos:

  • cat() can be called as torch.cat() but not as a method of a tensor.
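  • The 1st argument with torch is tensors(Required-Type:tuple or list of tensor of int, float, complex or bool). *All non-empty tensors must have the same number of dimensions and the same size in every dimension except the one given by dim (the dimension being concatenated along). A 1D empty tensor such as torch.tensor([]) is skipped by this shape check but still takes part in dtype promotion, and tensors of different dtypes are promoted to a common dtype.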
  • The 2nd argument with torch is dim(Optional-Default:0-Type:int).
  • There is out argument with torch(Optional-Default:None-Type:tensor): *Memos:
    • out= must be used.
    • My post explains out argument.
  • concat() is an alias of cat().
import torch

# 1D inputs of equal length. A 1D tensor has a single axis, so leaving dim
# out, dim=0 and dim=-1 all concatenate along that same axis.
tensor1, tensor2, tensor3 = (torch.tensor([2, 7, 4]),
                             torch.tensor([8, 3, 2]),
                             torch.tensor([5, 0, 8]))
inputs = (tensor1, tensor2, tensor3)

torch.cat(tensors=inputs)
torch.cat(tensors=inputs, dim=0)
torch.cat(tensors=inputs, dim=-1)
# tensor([2, 7, 4, 8, 3, 2, 5, 0, 8])

# 1D inputs of different lengths (2, 3 and 1). Sizes may differ along the
# dimension being concatenated, and for 1D tensors that is the only one.
tensor1, tensor2, tensor3 = (torch.tensor([2, 7]),
                             torch.tensor([8, 3, 2]),
                             torch.tensor([5]))
inputs = (tensor1, tensor2, tensor3)

torch.cat(tensors=inputs)
torch.cat(tensors=inputs, dim=0)
torch.cat(tensors=inputs, dim=-1)
# tensor([2, 7, 8, 3, 2, 5])

# 2D inputs, each of shape (2, 3). dim=0 (same as dim=-2) stacks the rows
# of each input under one another; dim=1 (same as dim=-1) places the
# inputs side by side.
tensor1, tensor2, tensor3 = (torch.tensor([[2, 7, 4], [8, 3, 2]]),
                             torch.tensor([[5, 0, 8], [3, 6, 1]]),
                             torch.tensor([[9, 4, 7], [1, 0, 5]]))
inputs = (tensor1, tensor2, tensor3)

# Along rows: result shape (6, 3).
torch.cat(tensors=inputs)
torch.cat(tensors=inputs, dim=0)
torch.cat(tensors=inputs, dim=-2)
# tensor([[2, 7, 4],
#         [8, 3, 2],
#         [5, 0, 8],
#         [3, 6, 1],
#         [9, 4, 7],
#         [1, 0, 5]])

# Along columns: result shape (2, 9).
torch.cat(tensors=inputs, dim=1)
torch.cat(tensors=inputs, dim=-1)
# tensor([[2, 7, 4, 5, 0, 8, 9, 4, 7],
#         [8, 3, 2, 3, 6, 1, 1, 0, 5]])

# 2D inputs with different row counts (2, 3 and 1) but the same column
# count (3). Only the dimension being concatenated along may differ in
# size, so these can be joined along dim=0 (same as dim=-2) only.
# dim=1 (same as dim=-1) raises a RuntimeError because the row counts
# do not match.
tensor1 = torch.tensor([[2, 7, 4], [8, 3, 2]])
tensor2 = torch.tensor([[5, 0, 8], [3, 6, 1], [9, 4, 7]])
tensor3 = torch.tensor([[1, 0, 5]])

torch.cat(tensors=(tensor1, tensor2, tensor3))
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=0)
torch.cat(tensors=(tensor1, tensor2, tensor3), dim=-2)
# tensor([[2, 7, 4],
#         [8, 3, 2],
#         [5, 0, 8],
#         [3, 6, 1],
#         [9, 4, 7],
#         [1, 0, 5]])

# 3D int inputs, each of shape (2, 2, 3). Every non-negative dim has a
# negative twin that names the same axis: 0 == -3, 1 == -2, 2 == -1.
tensor1, tensor2, tensor3 = (
    torch.tensor([[[2, 7, 4], [8, 3, 2]],
                  [[5, 0, 8], [3, 6, 1]]]),
    torch.tensor([[[9, 4, 7], [1, 0, 5]],
                  [[6, 7, 4], [2, 1, 9]]]),
    torch.tensor([[[1, 6, 3], [9, 6, 0]],
                  [[0, 8, 7], [3, 5, 2]]]),
)
inputs = (tensor1, tensor2, tensor3)

# Along the outermost axis: result shape (6, 2, 3).
torch.cat(tensors=inputs)
torch.cat(tensors=inputs, dim=0)
torch.cat(tensors=inputs, dim=-3)
# tensor([[[2, 7, 4], [8, 3, 2]],
#         [[5, 0, 8], [3, 6, 1]],
#         [[9, 4, 7], [1, 0, 5]],
#         [[6, 7, 4], [2, 1, 9]],
#         [[1, 6, 3], [9, 6, 0]],
#         [[0, 8, 7], [3, 5, 2]]])

# Along the middle axis: result shape (2, 6, 3).
torch.cat(tensors=inputs, dim=1)
torch.cat(tensors=inputs, dim=-2)
# tensor([[[2, 7, 4],
#          [8, 3, 2],
#          [9, 4, 7],
#          [1, 0, 5],
#          [1, 6, 3],
#          [9, 6, 0]],
#         [[5, 0, 8],
#          [3, 6, 1],
#          [6, 7, 4],
#          [2, 1, 9],
#          [0, 8, 7],
#          [3, 5, 2]]])

# Along the innermost axis: result shape (2, 2, 9).
torch.cat(tensors=inputs, dim=2)
torch.cat(tensors=inputs, dim=-1)
# tensor([[[2, 7, 4, 9, 4, 7, 1, 6, 3],
#          [8, 3, 2, 1, 0, 5, 9, 6, 0]],
#         [[5, 0, 8, 6, 7, 4, 0, 8, 7],
#          [3, 6, 1, 2, 1, 9, 3, 5, 2]]])

# The same 3D data as floats: cat() keeps the inputs' float dtype and
# joins them along the default dim=0, giving shape (6, 2, 3).
tensor1, tensor2, tensor3 = (
    torch.tensor([[[2., 7., 4.], [8., 3., 2.]],
                  [[5., 0., 8.], [3., 6., 1.]]]),
    torch.tensor([[[9., 4., 7.], [1., 0., 5.]],
                  [[6., 7., 4.], [2., 1., 9.]]]),
    torch.tensor([[[1., 6., 3.], [9., 6., 0.]],
                  [[0., 8., 7.], [3., 5., 2.]]]),
)
inputs = (tensor1, tensor2, tensor3)
torch.cat(tensors=inputs)
# tensor([[[2., 7., 4.], [8., 3., 2.]],
#         [[5., 0., 8.], [3., 6., 1.]],
#         [[9., 4., 7.], [1., 0., 5.]],
#         [[6., 7., 4.], [2., 1., 9.]],
#         [[1., 6., 3.], [9., 6., 0.]],
#         [[0., 8., 7.], [3., 5., 2.]]])

# The same 3D data as complex numbers (zero imaginary part). cat() keeps
# the complex dtype and joins along the default dim=0: shape (6, 2, 3).
tensor1, tensor2, tensor3 = (
    torch.tensor([[[2.+0.j, 7.+0.j, 4.+0.j],
                   [8.+0.j, 3.+0.j, 2.+0.j]],
                  [[5.+0.j, 0.+0.j, 8.+0.j],
                   [3.+0.j, 6.+0.j, 1.+0.j]]]),
    torch.tensor([[[9.+0.j, 4.+0.j, 7.+0.j],
                   [1.+0.j, 0.+0.j, 5.+0.j]],
                  [[6.+0.j, 7.+0.j, 4.+0.j],
                   [2.+0.j, 1.+0.j, 9.+0.j]]]),
    torch.tensor([[[1.+0.j, 6.+0.j, 3.+0.j],
                   [9.+0.j, 6.+0.j, 0.+0.j]],
                  [[0.+0.j, 8.+0.j, 7.+0.j],
                   [3.+0.j, 5.+0.j, 2.+0.j]]]),
)
inputs = (tensor1, tensor2, tensor3)
torch.cat(tensors=inputs)
# tensor([[[2.+0.j, 7.+0.j, 4.+0.j],
#          [8.+0.j, 3.+0.j, 2.+0.j]],
#         [[5.+0.j, 0.+0.j, 8.+0.j],
#          [3.+0.j, 6.+0.j, 1.+0.j]],
#         [[9.+0.j, 4.+0.j, 7.+0.j],
#          [1.+0.j, 0.+0.j, 5.+0.j]],
#         [[6.+0.j, 7.+0.j, 4.+0.j],
#          [2.+0.j, 1.+0.j, 9.+0.j]],
#         [[1.+0.j, 6.+0.j, 3.+0.j],
#          [9.+0.j, 6.+0.j, 0.+0.j]],
#         [[0.+0.j, 8.+0.j, 7.+0.j],
#          [3.+0.j, 5.+0.j, 2.+0.j]]])

# Boolean 3D inputs of shape (2, 2, 3). cat() keeps the bool dtype and
# joins along the default dim=0, giving shape (6, 2, 3).
tensor1, tensor2, tensor3 = (
    torch.tensor([[[True, False, True], [True, False, True]],
                  [[False, True, False], [False, True, False]]]),
    torch.tensor([[[False, True, False], [False, True, False]],
                  [[True, False, True], [True, False, True]]]),
    torch.tensor([[[True, False, True], [True, False, True]],
                  [[False, True, False], [False, True, False]]]),
)
inputs = (tensor1, tensor2, tensor3)
torch.cat(tensors=inputs)
# tensor([[[True, False, True], [True, False, True]],
#         [[False, True, False], [False, True, False]],
#         [[False, True, False], [False, True, False]],
#         [[True, False, True], [True, False, True]],
#         [[True, False, True], [True, False, True]],
#         [[False, True, False], [False, True, False]]])

# A 1D empty tensor mixed with 3D int tensors. torch.tensor([]) has shape
# (0,) and the default float dtype (float32 unless the default is changed).
# torch.cat() skips such 1D empty tensors when checking shapes, so it does
# not have to be 3D like the others. It still takes part in dtype promotion,
# though: int64 combined with float32 gives float32. That is why the result
# below holds floats even though tensor1 and tensor3 are ints.
tensor1 = torch.tensor([[[0, 1, 2]]])
tensor2 = torch.tensor([])
tensor3 = torch.tensor([[[0, 1, 2]]])

torch.cat(tensors=(tensor1, tensor2, tensor3))
# tensor([[[0., 1., 2.]],
#         [[0., 1., 2.]]])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)