DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Updated on

squeeze() and unsqueeze() in PyTorch

squeeze() removes dimensions of size 1 from a tensor of any rank (0D or more), as shown below:

*Memos:

  • squeeze() can be called either from torch (torch.squeeze(tensor, ...)) or on a tensor (tensor.squeeze(...)).
  • You can pass one or more dimensions (an int, or a tuple of ints) as the 2nd argument with torch or the 1st argument with a tensor. Then only the specified dimensions whose size is 1 are removed. *A specified dimension whose size is not 1 is simply left as it is; no error is raised. Passing a tuple of dimensions requires PyTorch 2.0 or later.
import torch

# Each group of equivalent calls below is followed by the result that every
# call in the group produces. The calls are written as bare expressions (as
# in a REPL); their results are not stored.
# NOTE: passing a tuple of dimensions to squeeze() requires PyTorch 2.0+.
my_tensor = torch.tensor([[[[0], [1]],
                           [[2], [3]],
                           [[4], [5]]]])
# The size is [1, 3, 2, 1].

# No dims given, or both size-1 dims (0 and 3) given: both are removed.
torch.squeeze(my_tensor)
my_tensor.squeeze()
torch.squeeze(my_tensor, (0, 3))
my_tensor.squeeze((0, 3))
torch.squeeze(my_tensor, (0, -1))
my_tensor.squeeze((0, -1))
torch.squeeze(my_tensor, (3, 0))
my_tensor.squeeze((3, 0))
torch.squeeze(my_tensor, (3, -4))
my_tensor.squeeze((3, -4))
torch.squeeze(my_tensor, (-1, 0))
my_tensor.squeeze((-1, 0))
torch.squeeze(my_tensor, (-4, 3))
my_tensor.squeeze((-4, 3))
torch.squeeze(my_tensor, (0, 1, 3))
my_tensor.squeeze((0, 1, 3))
# etc.
torch.squeeze(my_tensor, (0, 1, 2, 3))
my_tensor.squeeze((0, 1, 2, 3))
# etc.
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])
# The size is [3, 2].

# Dim 0 (== -4) is the only size-1 dim among those given: only it is removed.
torch.squeeze(my_tensor, 0)
my_tensor.squeeze(0)
torch.squeeze(my_tensor, (0,))
my_tensor.squeeze((0,))
torch.squeeze(my_tensor, -4)
my_tensor.squeeze(-4)
torch.squeeze(my_tensor, (-4,))
my_tensor.squeeze((-4,))
torch.squeeze(my_tensor, (0, 1))
my_tensor.squeeze((0, 1))
torch.squeeze(my_tensor, (0, 2))
my_tensor.squeeze((0, 2))
torch.squeeze(my_tensor, (0, -2))
my_tensor.squeeze((0, -2))
torch.squeeze(my_tensor, (0, -3))
my_tensor.squeeze((0, -3))
torch.squeeze(my_tensor, (1, 0))
my_tensor.squeeze((1, 0))
torch.squeeze(my_tensor, (1, -4))
my_tensor.squeeze((1, -4))
torch.squeeze(my_tensor, (2, 0))
my_tensor.squeeze((2, 0))
torch.squeeze(my_tensor, (2, -4))
my_tensor.squeeze((2, -4))
torch.squeeze(my_tensor, (-2, 0))
my_tensor.squeeze((-2, 0))
torch.squeeze(my_tensor, (-2, -4))
my_tensor.squeeze((-2, -4))
torch.squeeze(my_tensor, (-3, 0))
my_tensor.squeeze((-3, 0))
torch.squeeze(my_tensor, (-3, -4))
my_tensor.squeeze((-3, -4))
torch.squeeze(my_tensor, (-4, 1))
my_tensor.squeeze((-4, 1))
torch.squeeze(my_tensor, (-4, 2))
my_tensor.squeeze((-4, 2))
torch.squeeze(my_tensor, (0, 1, 2))
my_tensor.squeeze((0, 1, 2))
# etc.
# tensor([[[0], [1]],
#         [[2], [3]],
#         [[4], [5]]])
# The size is [3, 2, 1].

# None of the given dims (1 and 2, sizes 3 and 2) has size 1: nothing changes.
torch.squeeze(my_tensor, 1)
my_tensor.squeeze(1)
torch.squeeze(my_tensor, (1,))
my_tensor.squeeze((1,))
torch.squeeze(my_tensor, 2)
my_tensor.squeeze(2)
torch.squeeze(my_tensor, (2,))
my_tensor.squeeze((2,))
torch.squeeze(my_tensor, -2)
my_tensor.squeeze(-2)
torch.squeeze(my_tensor, (-2,))
my_tensor.squeeze((-2,))
torch.squeeze(my_tensor, -3)
my_tensor.squeeze(-3)
torch.squeeze(my_tensor, (-3,))
my_tensor.squeeze((-3,))
torch.squeeze(my_tensor, (1, 2))
my_tensor.squeeze((1, 2))
torch.squeeze(my_tensor, (1, -2))
my_tensor.squeeze((1, -2))
torch.squeeze(my_tensor, (2, 1))
my_tensor.squeeze((2, 1))
torch.squeeze(my_tensor, (2, -3))
my_tensor.squeeze((2, -3))
torch.squeeze(my_tensor, (-2, 1))
my_tensor.squeeze((-2, 1))
torch.squeeze(my_tensor, (-2, -3))
my_tensor.squeeze((-2, -3))
torch.squeeze(my_tensor, (-3, 2))
my_tensor.squeeze((-3, 2))
torch.squeeze(my_tensor, (-3, -2))
my_tensor.squeeze((-3, -2))
# etc.
# tensor([[[[0], [1]],
#          [[2], [3]],
#          [[4], [5]]]])
# The size is [1, 3, 2, 1].

# Dim 3 (== -1) is the only size-1 dim among those given: only it is removed.
torch.squeeze(my_tensor, 3)
my_tensor.squeeze(3)
torch.squeeze(my_tensor, (3,))
my_tensor.squeeze((3,))
torch.squeeze(my_tensor, -1)
my_tensor.squeeze(-1)
torch.squeeze(my_tensor, (-1,))
my_tensor.squeeze((-1,))
torch.squeeze(my_tensor, (1, 3))
my_tensor.squeeze((1, 3))
torch.squeeze(my_tensor, (1, -1))
my_tensor.squeeze((1, -1))
torch.squeeze(my_tensor, (2, 3))
my_tensor.squeeze((2, 3))
torch.squeeze(my_tensor, (2, -1))
my_tensor.squeeze((2, -1))
torch.squeeze(my_tensor, (3, 1))
my_tensor.squeeze((3, 1))
torch.squeeze(my_tensor, (3, 2))
my_tensor.squeeze((3, 2))
torch.squeeze(my_tensor, (3, -2))
my_tensor.squeeze((3, -2))
torch.squeeze(my_tensor, (3, -3))
my_tensor.squeeze((3, -3))
torch.squeeze(my_tensor, (-1, 1))
my_tensor.squeeze((-1, 1))
torch.squeeze(my_tensor, (-1, 2))
my_tensor.squeeze((-1, 2))
torch.squeeze(my_tensor, (-2, 3))
my_tensor.squeeze((-2, 3))
torch.squeeze(my_tensor, (-2, -1))
my_tensor.squeeze((-2, -1))
torch.squeeze(my_tensor, (-3, 3))
my_tensor.squeeze((-3, 3))
torch.squeeze(my_tensor, (-3, -1))
my_tensor.squeeze((-3, -1))
torch.squeeze(my_tensor, (1, 2, 3))
my_tensor.squeeze((1, 2, 3))
# etc.
# tensor([[[0, 1],
#          [2, 3],
#          [4, 5]]])
# The size is [1, 3, 2].
Enter fullscreen mode Exit fullscreen mode

unsqueeze() inserts a new dimension of size 1 at a given position in a tensor of any rank (0D or more), as shown below:

*Memos:

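  • unsqueeze() can be called either from torch (torch.unsqueeze(tensor, dim)) or on a tensor (tensor.unsqueeze(dim)).
  • With torch, the 2nd argument is the position of the new dimension.
  • With a tensor, the 1st argument is the position of the new dimension.
  • The position must be in the range [-input.dim()-1, input.dim()]. For the 3D tensor below, that is -4 to 3.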
import torch

# The values 0..23 laid out as 3 x 2 x 4 (int64), i.e. the same tensor as
#   [[[0, 1, 2, 3], [4, 5, 6, 7]],
#    [[8, 9, 10, 11], [12, 13, 14, 15]],
#    [[16, 17, 18, 19], [20, 21, 22, 23]]]
my_tensor = torch.arange(24).reshape(3, 2, 4)
# The size is [3, 2, 4].

# New size-1 dim in front: position 0, counted from the end as -4.
torch.unsqueeze(my_tensor, dim=0)
my_tensor.unsqueeze(dim=0)
torch.unsqueeze(my_tensor, dim=-4)
my_tensor.unsqueeze(dim=-4)
# tensor([[[[0, 1, 2, 3], [4, 5, 6, 7]],
#          [[8, 9, 10, 11], [12, 13, 14, 15]],
#          [[16, 17, 18, 19], [20, 21, 22, 23]]]])
# The size is [1, 3, 2, 4].

# New size-1 dim at position 1 (== -3).
torch.unsqueeze(my_tensor, dim=1)
my_tensor.unsqueeze(dim=1)
torch.unsqueeze(my_tensor, dim=-3)
my_tensor.unsqueeze(dim=-3)
# tensor([[[[0, 1, 2, 3], [4, 5, 6, 7]]],
#         [[[8, 9, 10, 11], [12, 13, 14, 15]]],
#         [[[16, 17, 18, 19], [20, 21, 22, 23]]]])
# The size is [3, 1, 2, 4].

# New size-1 dim at position 2 (== -2).
torch.unsqueeze(my_tensor, dim=2)
my_tensor.unsqueeze(dim=2)
torch.unsqueeze(my_tensor, dim=-2)
my_tensor.unsqueeze(dim=-2)
# tensor([[[[0, 1, 2, 3]], [[4, 5, 6, 7]]],
#         [[[8, 9, 10, 11]], [[12, 13, 14, 15]]],
#         [[[16, 17, 18, 19]], [[20, 21, 22, 23]]]])
# The size is [3, 2, 1, 4].

# New size-1 dim at the end: position 3 (== -1).
torch.unsqueeze(my_tensor, dim=3)
my_tensor.unsqueeze(dim=3)
torch.unsqueeze(my_tensor, dim=-1)
my_tensor.unsqueeze(dim=-1)
# tensor([[[[0], [1], [2], [3]], [[4], [5], [6], [7]]],
#         [[[8], [9], [10], [11]], [[12], [13], [14], [15]]],
#         [[[16], [17], [18], [19]], [[20], [21], [22], [23]]]])
# The size is [3, 2, 4, 1].
Enter fullscreen mode Exit fullscreen mode

Top comments (0)