DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

movedim in PyTorch

Buy Me a Coffee

*Memos:

movedim() returns a 0D or more D tensor of zero or more elements whose dimensions have been moved to new positions, without losing any data, from a 0D or more D tensor of zero or more elements, as shown below:

*Memos:

  • movedim() can be used with torch or a tensor.
  • The 1st argument (input) with torch, or the tensor itself when movedim() is called as a tensor method (Required-Type: tensor of int, float, complex or bool).
  • The 2nd argument with torch, or the 1st argument with a tensor, is source (Required-Type: int, tuple of int or list of int). *Each number must be unique.
  • The 3rd argument with torch, or the 2nd argument with a tensor, is destination (Required-Type: int, tuple of int or list of int). *Each number must be unique.
  • source and destination must contain the same number of dimensions.
  • moveaxis() is an alias of movedim().
import torch

# my_tensor has shape (4, 2, 3); element [i][j][k] equals 6*i + 3*j + k.
my_tensor = torch.tensor([[[0, 1, 2], [3, 4, 5]],
                          [[6, 7, 8], [9, 10, 11]],
                          [[12, 13, 14], [15, 16, 17]],
                          [[18, 19, 20], [21, 22, 23]]])

# Every call inside one group below returns the same tensor. "# ..." and
# "# etc." mark further equivalent spellings that are left out; they are
# comments (not a bare `etc.`) so this snippet can be pasted and run as-is.

# Dimension order unchanged -> shape (4, 2, 3), same values as my_tensor.
torch.movedim(input=my_tensor, source=0, destination=0)
torch.movedim(input=my_tensor, source=(0,), destination=(0,))
my_tensor.movedim(source=0, destination=0)
my_tensor.movedim(source=(0,), destination=(0,))
torch.movedim(input=my_tensor, source=0, destination=-3)
torch.movedim(input=my_tensor, source=(0,), destination=(-3,))
# ...
torch.movedim(input=my_tensor, source=(0, 1), destination=(0, 1))
torch.movedim(input=my_tensor, source=(0, 1), destination=(0, -2))
torch.movedim(input=my_tensor, source=(0, 1), destination=(-3, 1))
# ...
torch.movedim(input=my_tensor, source=(0, 1, 2), destination=(0, 1, 2))
# etc.
# tensor([[[0, 1, 2], [3, 4, 5]],
#         [[6, 7, 8], [9, 10, 11]],
#         [[12, 13, 14], [15, 16, 17]],
#         [[18, 19, 20], [21, 22, 23]]])

# Dims 0 and 1 swap places -> shape (2, 4, 3), like my_tensor.permute(1, 0, 2).
torch.movedim(input=my_tensor, source=0, destination=1)
torch.movedim(input=my_tensor, source=(0,), destination=(1,))
torch.movedim(input=my_tensor, source=0, destination=-2)
torch.movedim(input=my_tensor, source=(0,), destination=(-2,))
# ...
torch.movedim(input=my_tensor, source=(0, 1), destination=(1, 0))
torch.movedim(input=my_tensor, source=(0, 1), destination=(1, -3))
torch.movedim(input=my_tensor, source=(0, 1), destination=(-2, 0))
# ...
torch.movedim(input=my_tensor, source=(0, 1, 2), destination=(1, 0, 2))
# etc.
# tensor([[[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]],
#         [[3, 4, 5], [9, 10, 11], [15, 16, 17], [21, 22, 23]]])

# Dim 0 moves to the end -> shape (2, 3, 4), like my_tensor.permute(1, 2, 0).
torch.movedim(input=my_tensor, source=0, destination=2)
torch.movedim(input=my_tensor, source=(0,), destination=(2,))
torch.movedim(input=my_tensor, source=0, destination=-1)
torch.movedim(input=my_tensor, source=(0,), destination=(-1,))
# ...
torch.movedim(input=my_tensor, source=(0, 1), destination=(2, 0))
torch.movedim(input=my_tensor, source=(0, 1), destination=(2, -3))
torch.movedim(input=my_tensor, source=(0, 1), destination=(-1, 0))
# ...
torch.movedim(input=my_tensor, source=(0, 1, 2), destination=(2, 0, 1))
# etc.
# tensor([[[0, 6, 12, 18], [1, 7, 13, 19], [2, 8, 14, 20]],
#         [[3, 9, 15, 21], [4, 10, 16, 22], [5, 11, 17, 23]]])

# Dims 0 and 2 swap places -> shape (3, 2, 4), like my_tensor.permute(2, 1, 0).
torch.movedim(input=my_tensor, source=(0, 1), destination=(2, 1))
torch.movedim(input=my_tensor, source=(0, 1), destination=(2, -2))
torch.movedim(input=my_tensor, source=(0, 1), destination=(-1, 1))
torch.movedim(input=my_tensor, source=(0, 1), destination=(-1, -2))
torch.movedim(input=my_tensor, source=(0, 2), destination=(2, 0))
# ...
torch.movedim(input=my_tensor, source=(0, 1, 2), destination=(2, 1, 0))
# etc.
# tensor([[[0, 6, 12, 18], [3, 9, 15, 21]],
#         [[1, 7, 13, 19], [4, 10, 16, 22]],
#         [[2, 8, 14, 20], [5, 11, 17, 23]]])

# Dims 1 and 2 swap places -> shape (4, 3, 2), like my_tensor.permute(0, 2, 1).
torch.movedim(input=my_tensor, source=1, destination=2)
torch.movedim(input=my_tensor, source=(1,), destination=(2,))
torch.movedim(input=my_tensor, source=1, destination=-1)
torch.movedim(input=my_tensor, source=(1,), destination=(-1,))
# ...
torch.movedim(input=my_tensor, source=(0, 1), destination=(0, 2))
torch.movedim(input=my_tensor, source=(0, 1), destination=(0, -1))
torch.movedim(input=my_tensor, source=(0, 1), destination=(-3, 2))
# ...
torch.movedim(input=my_tensor, source=(0, 1, 2), destination=(0, 2, 1))
# etc.
# tensor([[[0, 3], [1, 4], [2, 5]],
#         [[6, 9], [7, 10], [8, 11]],
#         [[12, 15], [13, 16], [14, 17]],
#         [[18, 21], [19, 22], [20, 23]]])

# Dim 2 moves to the front -> shape (3, 4, 2), like my_tensor.permute(2, 0, 1).
torch.movedim(input=my_tensor, source=2, destination=0)
torch.movedim(input=my_tensor, source=(2,), destination=(0,))
torch.movedim(input=my_tensor, source=2, destination=-3)
torch.movedim(input=my_tensor, source=(2,), destination=(-3,))
# ...
torch.movedim(input=my_tensor, source=(0, 1), destination=(1, 2))
torch.movedim(input=my_tensor, source=(0, 1), destination=(1, -1))
torch.movedim(input=my_tensor, source=(0, 1), destination=(-2, 2))
# ...
torch.movedim(input=my_tensor, source=(0, 1, 2), destination=(1, 2, 0))
# etc.
# tensor([[[0, 3], [6, 9], [12, 15], [18, 21]],
#         [[1, 4], [7, 10], [13, 16], [19, 22]],
#         [[2, 5], [8, 11], [14, 17], [20, 23]]])

# Float example: shape (4, 2, 3); moving dim 0 to position 0 changes nothing.
my_tensor = torch.tensor([
    [[0., 1., 2.], [3., 4., 5.]],
    [[6., 7., 8.], [9., 10., 11.]],
    [[12., 13., 14.], [15., 16., 17.]],
    [[18., 19., 20.], [21., 22., 23.]],
])
my_tensor.movedim(source=0, destination=0)
# tensor([[[0., 1., 2.], [3., 4., 5.]],
#         [[6., 7., 8.], [9., 10., 11.]],
#         [[12., 13., 14.], [15., 16., 17.]],
#         [[18., 19., 20.], [21., 22., 23.]]])

# Complex example: shape (4, 2, 3); moving dim 0 to position 0 changes nothing.
my_tensor = torch.tensor([
    [[0.+0.j, 1.+0.j, 2.+0.j], [3.+0.j, 4.+0.j, 5.+0.j]],
    [[6.+0.j, 7.+0.j, 8.+0.j], [9.+0.j, 10.+0.j, 11.+0.j]],
    [[12.+0.j, 13.+0.j, 14.+0.j], [15.+0.j, 16.+0.j, 17.+0.j]],
    [[18.+0.j, 19.+0.j, 20.+0.j], [21.+0.j, 22.+0.j, 23.+0.j]],
])
my_tensor.movedim(source=0, destination=0)
# tensor([[[0.+0.j, 1.+0.j, 2.+0.j],
#          [3.+0.j, 4.+0.j, 5.+0.j]],
#         [[6.+0.j, 7.+0.j, 8.+0.j],
#          [9.+0.j, 10.+0.j, 11.+0.j]],
#         [[12.+0.j, 13.+0.j, 14.+0.j],
#          [15.+0.j, 16.+0.j, 17.+0.j]],
#         [[18.+0.j, 19.+0.j, 20.+0.j],
#          [21.+0.j, 22.+0.j, 23.+0.j]]])

# Bool example: shape (4, 2, 3); moving dim 0 to position 0 changes nothing.
my_tensor = torch.tensor([
    [[True, False, True], [True, False, True]],
    [[False, True, False], [False, True, False]],
    [[True, False, True], [True, False, True]],
    [[False, True, False], [False, True, False]],
])
my_tensor.movedim(source=0, destination=0)
# tensor([[[True, False, True], [True, False, True]],
#         [[False, True, False], [False, True, False]],
#         [[True, False, True], [True, False, True]],
#         [[False, True, False], [False, True, False]]])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)