DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

diagonal() and diag_embed() in PyTorch

*My post explains eye(), diag() and diagflat().

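diagonal() can extract the elements on a diagonal of a 2D or higher-dimensional tensor as a new tensor, as shown below. For a 3D or higher tensor, the two selected dimensions are removed, the remaining dimensions come first and the diagonal becomes the last dimension of the result: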

*Memos:

  • diagonal() can be used with torch or a tensor.
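  • A tensor of zero or more integers, floating-point numbers, complex numbers or boolean values can be used.
  • The 2nd argument (int) with torch or the 1st argument (int) with a tensor is offset (Optional, Default: 0). Taking dim1 as rows and dim2 as columns, a positive offset selects a diagonal above the main diagonal and a negative offset selects one below it.
  • The 3rd argument (int) with torch or the 2nd argument (int) with a tensor is dim1 (Optional, Default: 0). Negative values count from the last dimension.
  • The 4th argument (int) with torch or the 3rd argument (int) with a tensor is dim2 (Optional, Default: 1). Swapping dim1 and dim2 while negating offset selects the same elements.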
import torch

my_tensor = torch.tensor([[7, -4, 5],
                          [-6, -3, 8],
                          [9, 1, -2]])

# offset=0 picks the main diagonal. In a 2D tensor dim -1 is dim 1 and
# dim -2 is dim 0, and swapping dim1/dim2 leaves the main diagonal where it
# is, so every call in this group returns the same tensor.
main_diagonal = (
    torch.diagonal(my_tensor),
    my_tensor.diagonal(),
    torch.diagonal(my_tensor, offset=0),
    torch.diagonal(my_tensor, offset=0, dim1=0, dim2=1),
    torch.diagonal(my_tensor, offset=0, dim1=0, dim2=-1),
    torch.diagonal(my_tensor, offset=0, dim1=1, dim2=0),
    torch.diagonal(my_tensor, offset=0, dim1=1, dim2=-2),
    torch.diagonal(my_tensor, offset=0, dim1=-1, dim2=0),
    torch.diagonal(my_tensor, offset=0, dim1=-1, dim2=-2),
    torch.diagonal(my_tensor, offset=0, dim1=-2, dim2=1),
    torch.diagonal(my_tensor, offset=0, dim1=-2, dim2=-1),
)
# tensor([7, -3, -2])

# One step above the main diagonal. Swapping dim1 and dim2 mirrors the
# matrix, so offset=-1 with (dim1=1, dim2=0) reaches the same elements.
above_by_1 = (
    torch.diagonal(my_tensor, offset=1),
    torch.diagonal(my_tensor, offset=1, dim1=0, dim2=1),
    torch.diagonal(my_tensor, offset=1, dim1=0, dim2=-1),
    torch.diagonal(my_tensor, offset=1, dim1=-2, dim2=1),
    torch.diagonal(my_tensor, offset=1, dim1=-2, dim2=-1),
    torch.diagonal(my_tensor, offset=-1, dim1=1, dim2=0),
    torch.diagonal(my_tensor, offset=-1, dim1=1, dim2=-2),
    torch.diagonal(my_tensor, offset=-1, dim1=-1, dim2=0),
    torch.diagonal(my_tensor, offset=-1, dim1=-1, dim2=-2),
)
# tensor([-4, 8])

# One step below the main diagonal (or one step "above" once the dims are
# swapped).
below_by_1 = (
    torch.diagonal(my_tensor, offset=-1),
    torch.diagonal(my_tensor, offset=1, dim1=1, dim2=0),
    torch.diagonal(my_tensor, offset=1, dim1=1, dim2=-2),
    torch.diagonal(my_tensor, offset=1, dim1=-1, dim2=0),
    torch.diagonal(my_tensor, offset=1, dim1=-1, dim2=-2),
    torch.diagonal(my_tensor, offset=-1, dim1=0, dim2=1),
    torch.diagonal(my_tensor, offset=-1, dim1=0, dim2=-1),
    torch.diagonal(my_tensor, offset=-1, dim1=-2, dim2=1),
    torch.diagonal(my_tensor, offset=-1, dim1=-2, dim2=-1),
)
# tensor([-6, 1])

# Two steps above: only the top-right corner element is left.
above_by_2 = (
    torch.diagonal(my_tensor, offset=2),
    torch.diagonal(my_tensor, offset=2, dim1=0, dim2=1),
    torch.diagonal(my_tensor, offset=2, dim1=0, dim2=-1),
    torch.diagonal(my_tensor, offset=2, dim1=-2, dim2=1),
    torch.diagonal(my_tensor, offset=2, dim1=-2, dim2=-1),
    torch.diagonal(my_tensor, offset=-2, dim1=1, dim2=0),
    torch.diagonal(my_tensor, offset=-2, dim1=1, dim2=-2),
    torch.diagonal(my_tensor, offset=-2, dim1=-1, dim2=0),
    torch.diagonal(my_tensor, offset=-2, dim1=-1, dim2=-2),
)
# tensor([5])

# Two steps below: only the bottom-left corner element is left.
below_by_2 = (
    torch.diagonal(my_tensor, offset=-2),
    torch.diagonal(my_tensor, offset=2, dim1=1, dim2=0),
    torch.diagonal(my_tensor, offset=2, dim1=1, dim2=-2),
    torch.diagonal(my_tensor, offset=2, dim1=-1, dim2=0),
    torch.diagonal(my_tensor, offset=2, dim1=-1, dim2=-2),
    torch.diagonal(my_tensor, offset=-2, dim1=0, dim2=1),
    torch.diagonal(my_tensor, offset=-2, dim1=0, dim2=-1),
    torch.diagonal(my_tensor, offset=-2, dim1=-2, dim2=1),
    torch.diagonal(my_tensor, offset=-2, dim1=-2, dim2=-1),
)
# tensor([9])

# diagonal() works the same way for every dtype; the extracted elements keep
# the dtype of the input tensor.
my_tensor = torch.tensor([[7., -4., 5.],
                          [-6., -3., 8.],
                          [9., 1., -2.]])
float_diagonal = torch.diagonal(my_tensor)
# tensor([7., -3., -2.])

my_tensor = torch.tensor([[7+0j, -4+0j, 5+0j],
                          [-6+0j, -3+0j, 8+0j],
                          [9+0j, 1+0j, -2+0j]])
complex_diagonal = torch.diagonal(my_tensor)
# tensor([7.+0.j, -3.+0.j, -2.+0.j])

my_tensor = torch.tensor([[False, True, True],
                          [True, False, True],
                          [True, True, False]])
bool_diagonal = torch.diagonal(my_tensor)
# tensor([False, False, False])

my_tensor = torch.tensor([[[7, -4, 5], [-6, -3, 8], [9, 1, -2]],
                          [[3, -1, 8], [0, 1, 6], [-7, 4, -9]],
                          [[6, -8, -9], [-4, 5, 0], [-3, -5, 2]]])

# For a 3D tensor dims -1, -2 and -3 are dims 2, 1 and 0. The two selected
# dims are removed: the dim that is left becomes the rows of the result, and
# the diagonal becomes its last dimension (the columns).
# offset=k with (dim1=a, dim2=b) selects the same elements as offset=-k with
# (dim1=b, dim2=a), so each group below lists both spellings.

# Main diagonal across dims 0 and 1 (the defaults).
main_dims_0_1 = (
    torch.diagonal(my_tensor),
    torch.diagonal(my_tensor, offset=0),
    torch.diagonal(my_tensor, offset=0, dim1=0, dim2=1),
    torch.diagonal(my_tensor, offset=0, dim1=0, dim2=-2),
    torch.diagonal(my_tensor, offset=0, dim1=1, dim2=0),
    torch.diagonal(my_tensor, offset=0, dim1=1, dim2=-3),
    torch.diagonal(my_tensor, offset=0, dim1=-2, dim2=0),
    torch.diagonal(my_tensor, offset=0, dim1=-2, dim2=-3),
    torch.diagonal(my_tensor, offset=0, dim1=-3, dim2=1),
    torch.diagonal(my_tensor, offset=0, dim1=-3, dim2=-2),
)
# tensor([[7, 0, -3],
#         [-4, 1, -5],
#         [5, 6, 2]])

# Main diagonal across dims 0 and 2.
main_dims_0_2 = (
    torch.diagonal(my_tensor, offset=0, dim1=0, dim2=2),
    torch.diagonal(my_tensor, offset=0, dim1=0, dim2=-1),
    torch.diagonal(my_tensor, offset=0, dim1=2, dim2=0),
    torch.diagonal(my_tensor, offset=0, dim1=2, dim2=-3),
    torch.diagonal(my_tensor, offset=0, dim1=-1, dim2=0),
    torch.diagonal(my_tensor, offset=0, dim1=-1, dim2=-3),
    torch.diagonal(my_tensor, offset=0, dim1=-3, dim2=2),
    torch.diagonal(my_tensor, offset=0, dim1=-3, dim2=-1),
)
# tensor([[7, -1, -9],
#         [-6, 1, 0],
#         [9, 4, 2]])

# Main diagonal across dims 1 and 2: one diagonal per 3x3 matrix.
main_dims_1_2 = (
    torch.diagonal(my_tensor, offset=0, dim1=1, dim2=2),
    torch.diagonal(my_tensor, offset=0, dim1=1, dim2=-1),
    torch.diagonal(my_tensor, offset=0, dim1=2, dim2=1),
    torch.diagonal(my_tensor, offset=0, dim1=2, dim2=-2),
    torch.diagonal(my_tensor, offset=0, dim1=-1, dim2=1),
    torch.diagonal(my_tensor, offset=0, dim1=-1, dim2=-2),
    torch.diagonal(my_tensor, offset=0, dim1=-2, dim2=2),
    torch.diagonal(my_tensor, offset=0, dim1=-2, dim2=-1),
)
# tensor([[7, -3, -2],
#         [3, 1, -9],
#         [6, 5, 2]])

# offset=1 with (dim1=0, dim2=1).
plus1_dims_0_1 = (
    torch.diagonal(my_tensor, offset=1),
    torch.diagonal(my_tensor, offset=1, dim1=0, dim2=1),
    torch.diagonal(my_tensor, offset=1, dim1=0, dim2=-2),
    torch.diagonal(my_tensor, offset=1, dim1=-3, dim2=1),
    torch.diagonal(my_tensor, offset=1, dim1=-3, dim2=-2),
    torch.diagonal(my_tensor, offset=-1, dim1=1, dim2=0),
    torch.diagonal(my_tensor, offset=-1, dim1=1, dim2=-3),
    torch.diagonal(my_tensor, offset=-1, dim1=-2, dim2=0),
    torch.diagonal(my_tensor, offset=-1, dim1=-2, dim2=-3),
)
# tensor([[-6, -7],
#         [-3, 4],
#         [8, -9]])

# offset=1 with (dim1=0, dim2=2).
plus1_dims_0_2 = (
    torch.diagonal(my_tensor, offset=1, dim1=0, dim2=2),
    torch.diagonal(my_tensor, offset=1, dim1=0, dim2=-1),
    torch.diagonal(my_tensor, offset=1, dim1=-3, dim2=2),
    torch.diagonal(my_tensor, offset=1, dim1=-3, dim2=-1),
    torch.diagonal(my_tensor, offset=-1, dim1=2, dim2=0),
    torch.diagonal(my_tensor, offset=-1, dim1=2, dim2=-3),
    torch.diagonal(my_tensor, offset=-1, dim1=-1, dim2=0),
    torch.diagonal(my_tensor, offset=-1, dim1=-1, dim2=-3),
)
# tensor([[-4, 8],
#         [-3, 6],
#         [1, -9]])

# offset=1 with (dim1=1, dim2=0), i.e. offset=-1 with the default dims.
plus1_dims_1_0 = (
    torch.diagonal(my_tensor, offset=-1),
    torch.diagonal(my_tensor, offset=1, dim1=1, dim2=0),
    torch.diagonal(my_tensor, offset=1, dim1=1, dim2=-3),
    torch.diagonal(my_tensor, offset=1, dim1=-2, dim2=0),
    torch.diagonal(my_tensor, offset=1, dim1=-2, dim2=-3),
    torch.diagonal(my_tensor, offset=-1, dim1=0, dim2=1),
    torch.diagonal(my_tensor, offset=-1, dim1=0, dim2=-2),
    torch.diagonal(my_tensor, offset=-1, dim1=-3, dim2=1),
    torch.diagonal(my_tensor, offset=-1, dim1=-3, dim2=-2),
)
# tensor([[3, -4],
#         [-1, 5],
#         [8, 0]])

# offset=1 with (dim1=1, dim2=2).
plus1_dims_1_2 = (
    torch.diagonal(my_tensor, offset=1, dim1=1, dim2=2),
    torch.diagonal(my_tensor, offset=1, dim1=1, dim2=-1),
    torch.diagonal(my_tensor, offset=1, dim1=-2, dim2=2),
    torch.diagonal(my_tensor, offset=1, dim1=-2, dim2=-1),
    torch.diagonal(my_tensor, offset=-1, dim1=2, dim2=1),
    torch.diagonal(my_tensor, offset=-1, dim1=2, dim2=-2),
    torch.diagonal(my_tensor, offset=-1, dim1=-1, dim2=1),
    torch.diagonal(my_tensor, offset=-1, dim1=-1, dim2=-2),
)
# tensor([[-4, 8],
#         [-1, 6],
#         [-8, 0]])

# offset=1 with (dim1=2, dim2=0).
plus1_dims_2_0 = (
    torch.diagonal(my_tensor, offset=1, dim1=2, dim2=0),
    torch.diagonal(my_tensor, offset=1, dim1=2, dim2=-3),
    torch.diagonal(my_tensor, offset=1, dim1=-1, dim2=0),
    torch.diagonal(my_tensor, offset=1, dim1=-1, dim2=-3),
    torch.diagonal(my_tensor, offset=-1, dim1=0, dim2=2),
    torch.diagonal(my_tensor, offset=-1, dim1=0, dim2=-1),
    torch.diagonal(my_tensor, offset=-1, dim1=-3, dim2=2),
    torch.diagonal(my_tensor, offset=-1, dim1=-3, dim2=-1),
)
# tensor([[3, -8],
#         [0, 5],
#         [-7, -5]])

# offset=1 with (dim1=2, dim2=1).
plus1_dims_2_1 = (
    torch.diagonal(my_tensor, offset=1, dim1=2, dim2=1),
    torch.diagonal(my_tensor, offset=1, dim1=2, dim2=-2),
    torch.diagonal(my_tensor, offset=1, dim1=-1, dim2=1),
    torch.diagonal(my_tensor, offset=1, dim1=-1, dim2=-2),
    torch.diagonal(my_tensor, offset=-1, dim1=1, dim2=2),
    torch.diagonal(my_tensor, offset=-1, dim1=1, dim2=-1),
    torch.diagonal(my_tensor, offset=-1, dim1=-2, dim2=2),
    torch.diagonal(my_tensor, offset=-1, dim1=-2, dim2=-1),
)
# tensor([[-6, 1],
#         [0, 4],
#         [-4, -5]])

# offset=2 with (dim1=0, dim2=1).
plus2_dims_0_1 = (
    torch.diagonal(my_tensor, offset=2),
    torch.diagonal(my_tensor, offset=2, dim1=0, dim2=1),
    torch.diagonal(my_tensor, offset=2, dim1=0, dim2=-2),
    torch.diagonal(my_tensor, offset=2, dim1=-3, dim2=1),
    torch.diagonal(my_tensor, offset=2, dim1=-3, dim2=-2),
    torch.diagonal(my_tensor, offset=-2, dim1=1, dim2=0),
    torch.diagonal(my_tensor, offset=-2, dim1=1, dim2=-3),
    torch.diagonal(my_tensor, offset=-2, dim1=-2, dim2=0),
    torch.diagonal(my_tensor, offset=-2, dim1=-2, dim2=-3),
)
# tensor([[9],
#         [1],
#         [-2]])

# offset=2 with (dim1=0, dim2=2).
plus2_dims_0_2 = (
    torch.diagonal(my_tensor, offset=2, dim1=0, dim2=2),
    torch.diagonal(my_tensor, offset=2, dim1=0, dim2=-1),
    torch.diagonal(my_tensor, offset=2, dim1=-3, dim2=2),
    torch.diagonal(my_tensor, offset=2, dim1=-3, dim2=-1),
    torch.diagonal(my_tensor, offset=-2, dim1=2, dim2=0),
    torch.diagonal(my_tensor, offset=-2, dim1=2, dim2=-3),
    torch.diagonal(my_tensor, offset=-2, dim1=-1, dim2=0),
    torch.diagonal(my_tensor, offset=-2, dim1=-1, dim2=-3),
)
# tensor([[5],
#         [8],
#         [-2]])

# offset=2 with (dim1=1, dim2=0), i.e. offset=-2 with the default dims.
plus2_dims_1_0 = (
    torch.diagonal(my_tensor, offset=-2),
    torch.diagonal(my_tensor, offset=2, dim1=1, dim2=0),
    torch.diagonal(my_tensor, offset=2, dim1=1, dim2=-3),
    torch.diagonal(my_tensor, offset=2, dim1=-2, dim2=0),
    torch.diagonal(my_tensor, offset=2, dim1=-2, dim2=-3),
    torch.diagonal(my_tensor, offset=-2, dim1=0, dim2=1),
    torch.diagonal(my_tensor, offset=-2, dim1=0, dim2=-2),
    torch.diagonal(my_tensor, offset=-2, dim1=-3, dim2=1),
    torch.diagonal(my_tensor, offset=-2, dim1=-3, dim2=-2),
)
# tensor([[6],
#         [-8],
#         [-9]])

# offset=2 with (dim1=1, dim2=2).
plus2_dims_1_2 = (
    torch.diagonal(my_tensor, offset=2, dim1=1, dim2=2),
    torch.diagonal(my_tensor, offset=2, dim1=1, dim2=-1),
    torch.diagonal(my_tensor, offset=2, dim1=-2, dim2=2),
    torch.diagonal(my_tensor, offset=2, dim1=-2, dim2=-1),
    torch.diagonal(my_tensor, offset=-2, dim1=2, dim2=1),
    torch.diagonal(my_tensor, offset=-2, dim1=2, dim2=-2),
    torch.diagonal(my_tensor, offset=-2, dim1=-1, dim2=1),
    torch.diagonal(my_tensor, offset=-2, dim1=-1, dim2=-2),
)
# tensor([[5],
#         [8],
#         [-9]])

# offset=2 with (dim1=2, dim2=0).
plus2_dims_2_0 = (
    torch.diagonal(my_tensor, offset=2, dim1=2, dim2=0),
    torch.diagonal(my_tensor, offset=2, dim1=2, dim2=-3),
    torch.diagonal(my_tensor, offset=2, dim1=-1, dim2=0),
    torch.diagonal(my_tensor, offset=2, dim1=-1, dim2=-3),
    torch.diagonal(my_tensor, offset=-2, dim1=0, dim2=2),
    torch.diagonal(my_tensor, offset=-2, dim1=0, dim2=-1),
    torch.diagonal(my_tensor, offset=-2, dim1=-3, dim2=2),
    torch.diagonal(my_tensor, offset=-2, dim1=-3, dim2=-1),
)
# tensor([[6],
#         [-4],
#         [-3]])

# offset=2 with (dim1=2, dim2=1).
plus2_dims_2_1 = (
    torch.diagonal(my_tensor, offset=2, dim1=2, dim2=1),
    torch.diagonal(my_tensor, offset=2, dim1=2, dim2=-2),
    torch.diagonal(my_tensor, offset=2, dim1=-1, dim2=1),
    torch.diagonal(my_tensor, offset=2, dim1=-1, dim2=-2),
    torch.diagonal(my_tensor, offset=-2, dim1=1, dim2=2),
    torch.diagonal(my_tensor, offset=-2, dim1=1, dim2=-1),
    torch.diagonal(my_tensor, offset=-2, dim1=-2, dim2=2),
    torch.diagonal(my_tensor, offset=-2, dim1=-2, dim2=-1),
)
# tensor([[9],
#         [-7],
#         [-3]])
Enter fullscreen mode Exit fullscreen mode

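diag_embed() can create a tensor that holds the values of a 1D or higher-dimensional tensor on its diagonal(s), with 0, 0., 0.+0.j or False elsewhere, as shown below. The result has one more dimension than the input, and the last dimension of the input is the one laid along the diagonal: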

*Memos:

  • diag_embed() can be used with torch or a tensor.
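  • A tensor of zero or more integers, floating-point numbers, complex numbers or boolean values can be used.
  • The 2nd argument (int) with torch or the 1st argument (int) with a tensor is offset (Optional, Default: 0). Taking dim1 as rows and dim2 as columns, a positive offset places the values above the main diagonal and a negative offset places them below it. The two diagonal dimensions of the result grow by |offset| so the shifted diagonal fits.
  • The 3rd argument (int) with torch or the 2nd argument (int) with a tensor is dim1 (Optional, Default: -2). dim1 and dim2 refer to the dimensions of the result, not of the input.
  • The 4th argument (int) with torch or the 3rd argument (int) with a tensor is dim2 (Optional, Default: -1). With the defaults, the diagonal occupies the last two dimensions of the result, so a 2D input gives one diagonal matrix per row.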
import torch

my_tensor = torch.tensor([7, -4, 5])

# dim1/dim2 index the OUTPUT tensor, which is 2D here, so dim -1 is dim 1 and
# dim -2 is dim 0. With offset=0 the order of dim1/dim2 does not matter.
embed_main = (
    torch.diag_embed(my_tensor),
    my_tensor.diag_embed(),
    torch.diag_embed(my_tensor, offset=0),
    torch.diag_embed(my_tensor, offset=0, dim1=0, dim2=1),
    torch.diag_embed(my_tensor, offset=0, dim1=0, dim2=-1),
    torch.diag_embed(my_tensor, offset=0, dim1=1, dim2=0),
    torch.diag_embed(my_tensor, offset=0, dim1=1, dim2=-2),
    torch.diag_embed(my_tensor, offset=0, dim1=-1, dim2=0),
    torch.diag_embed(my_tensor, offset=0, dim1=-1, dim2=-2),
    torch.diag_embed(my_tensor, offset=0, dim1=-2, dim2=1),
    torch.diag_embed(my_tensor, offset=0, dim1=-2, dim2=-1),
)
# tensor([[7, 0, 0],
#         [0, -4, 0],
#         [0, 0, 5]])

# One step above the main diagonal; the output grows to 4x4 so the shifted
# diagonal fits. offset=-1 with swapped dims produces the same tensor.
embed_above_by_1 = (
    torch.diag_embed(my_tensor, offset=1),
    torch.diag_embed(my_tensor, offset=1, dim1=0, dim2=1),
    torch.diag_embed(my_tensor, offset=1, dim1=0, dim2=-1),
    torch.diag_embed(my_tensor, offset=1, dim1=-2, dim2=1),
    torch.diag_embed(my_tensor, offset=1, dim1=-2, dim2=-1),
    torch.diag_embed(my_tensor, offset=-1, dim1=1, dim2=0),
    torch.diag_embed(my_tensor, offset=-1, dim1=1, dim2=-2),
    torch.diag_embed(my_tensor, offset=-1, dim1=-1, dim2=0),
    torch.diag_embed(my_tensor, offset=-1, dim1=-1, dim2=-2),
)
# tensor([[0, 7, 0, 0],
#         [0, 0, -4, 0],
#         [0, 0, 0, 5],
#         [0, 0, 0, 0]])

# One step below the main diagonal.
embed_below_by_1 = (
    torch.diag_embed(my_tensor, offset=-1),
    torch.diag_embed(my_tensor, offset=1, dim1=1, dim2=0),
    torch.diag_embed(my_tensor, offset=1, dim1=1, dim2=-2),
    torch.diag_embed(my_tensor, offset=1, dim1=-1, dim2=0),
    torch.diag_embed(my_tensor, offset=1, dim1=-1, dim2=-2),
    torch.diag_embed(my_tensor, offset=-1, dim1=0, dim2=1),
    torch.diag_embed(my_tensor, offset=-1, dim1=0, dim2=-1),
    torch.diag_embed(my_tensor, offset=-1, dim1=-2, dim2=1),
    torch.diag_embed(my_tensor, offset=-1, dim1=-2, dim2=-1),
)
# tensor([[0, 0, 0, 0],
#         [7, 0, 0, 0],
#         [0, -4, 0, 0],
#         [0, 0, 5, 0]])

# Two steps above the main diagonal (5x5 output).
embed_above_by_2 = (
    torch.diag_embed(my_tensor, offset=2),
    torch.diag_embed(my_tensor, offset=2, dim1=0, dim2=1),
    torch.diag_embed(my_tensor, offset=2, dim1=0, dim2=-1),
    torch.diag_embed(my_tensor, offset=2, dim1=-2, dim2=1),
    torch.diag_embed(my_tensor, offset=2, dim1=-2, dim2=-1),
    torch.diag_embed(my_tensor, offset=-2, dim1=1, dim2=0),
    torch.diag_embed(my_tensor, offset=-2, dim1=1, dim2=-2),
    torch.diag_embed(my_tensor, offset=-2, dim1=-1, dim2=0),
    torch.diag_embed(my_tensor, offset=-2, dim1=-1, dim2=-2),
)
# tensor([[0, 0, 7, 0, 0],
#         [0, 0, 0, -4, 0],
#         [0, 0, 0, 0, 5],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0]])

# Two steps below the main diagonal (5x5 output).
embed_below_by_2 = (
    torch.diag_embed(my_tensor, offset=-2),
    torch.diag_embed(my_tensor, offset=2, dim1=1, dim2=0),
    torch.diag_embed(my_tensor, offset=2, dim1=1, dim2=-2),
    torch.diag_embed(my_tensor, offset=2, dim1=-1, dim2=0),
    torch.diag_embed(my_tensor, offset=2, dim1=-1, dim2=-2),
    torch.diag_embed(my_tensor, offset=-2, dim1=0, dim2=1),
    torch.diag_embed(my_tensor, offset=-2, dim1=0, dim2=-1),
    torch.diag_embed(my_tensor, offset=-2, dim1=-2, dim2=1),
    torch.diag_embed(my_tensor, offset=-2, dim1=-2, dim2=-1),
)
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [7, 0, 0, 0, 0],
#         [0, -4, 0, 0, 0],
#         [0, 0, 5, 0, 0]])

# The fill value off the diagonal is the zero of the input dtype:
# 0. for floats, 0.+0.j for complex numbers and False for booleans.
my_tensor = torch.tensor([7., -4., 5.])

float_embed = torch.diag_embed(my_tensor)
# tensor([[7., 0., 0.],
#         [0., -4., 0.],
#         [0., 0., 5.]])

my_tensor = torch.tensor([7+0j, -4+0j, 5+0j])

complex_embed = torch.diag_embed(my_tensor)
# tensor([[7.+0.j, 0.+0.j, 0.+0.j],
#         [0.+0.j, -4.+0.j, 0.+0.j],
#         [0.+0.j, 0.+0.j, 5.+0.j]])

my_tensor = torch.tensor([True, True, True])

bool_embed = torch.diag_embed(my_tensor)
# tensor([[True, False, False],
#         [False, True, False],
#         [False, False, True]])

my_tensor = torch.tensor([[7, -4, 5],
                          [-6, -3, 8],
                          [9, 1, -2]])

# For a 2D input the output is 3D, so dims 0, 1, 2 are dims -3, -2, -1.
# Each row of the input (its last dimension) is laid along the diagonal of
# the (dim1, dim2) plane; the remaining output dimension indexes the rows.

# Diagonal across output dims 0 and 1 (either order).
embed_dims_0_1 = (
    torch.diag_embed(my_tensor, offset=0, dim1=0, dim2=1),
    torch.diag_embed(my_tensor, offset=0, dim1=0, dim2=-2),
    torch.diag_embed(my_tensor, offset=0, dim1=1, dim2=0),
    torch.diag_embed(my_tensor, offset=0, dim1=1, dim2=-3),
    torch.diag_embed(my_tensor, offset=0, dim1=-2, dim2=0),
    torch.diag_embed(my_tensor, offset=0, dim1=-2, dim2=-3),
    torch.diag_embed(my_tensor, offset=0, dim1=-3, dim2=1),
    torch.diag_embed(my_tensor, offset=0, dim1=-3, dim2=-2),
)
# tensor([[[7, -6, 9],
#          [0, 0, 0],
#          [0, 0, 0]],
#         [[0, 0, 0],
#          [-4, -3, 1],
#          [0, 0, 0]],
#         [[0, 0, 0],
#          [0, 0, 0],
#          [5, 8, -2]]])

# Diagonal across output dims 0 and 2 (either order).
embed_dims_0_2 = (
    torch.diag_embed(my_tensor, offset=0, dim1=0, dim2=2),
    torch.diag_embed(my_tensor, offset=0, dim1=0, dim2=-1),
    torch.diag_embed(my_tensor, offset=0, dim1=2, dim2=0),
    torch.diag_embed(my_tensor, offset=0, dim1=2, dim2=-3),
    torch.diag_embed(my_tensor, offset=0, dim1=-1, dim2=0),
    torch.diag_embed(my_tensor, offset=0, dim1=-1, dim2=-3),
    torch.diag_embed(my_tensor, offset=0, dim1=-3, dim2=2),
    torch.diag_embed(my_tensor, offset=0, dim1=-3, dim2=-1),
)
# tensor([[[7, 0, 0],
#          [-6, 0, 0],
#          [9, 0, 0]],
#         [[0, -4, 0],
#          [0, -3, 0],
#          [0, 1, 0]],
#         [[0, 0, 5],
#          [0, 0, 8],
#          [0, 0, -2]]])

# Diagonal across output dims 1 and 2 (either order). diag_embed() defaults
# to dim1=-2, dim2=-1, i.e. the last two output dims, so the calls without
# dim arguments belong to this group: each input row becomes its own
# diagonal matrix.
embed_dims_1_2 = (
    torch.diag_embed(my_tensor),
    torch.diag_embed(my_tensor, offset=0),
    torch.diag_embed(my_tensor, offset=0, dim1=1, dim2=2),
    torch.diag_embed(my_tensor, offset=0, dim1=1, dim2=-1),
    torch.diag_embed(my_tensor, offset=0, dim1=2, dim2=1),
    torch.diag_embed(my_tensor, offset=0, dim1=2, dim2=-2),
    torch.diag_embed(my_tensor, offset=0, dim1=-1, dim2=1),
    torch.diag_embed(my_tensor, offset=0, dim1=-1, dim2=-2),
    torch.diag_embed(my_tensor, offset=0, dim1=-2, dim2=2),
    torch.diag_embed(my_tensor, offset=0, dim1=-2, dim2=-1),
)
# tensor([[[7, 0, 0],
#          [0, -4, 0],
#          [0, 0, 5]],
#         [[-6, 0, 0],
#          [0, -3, 0],
#          [0, 0, 8]],
#         [[9, 0, 0],
#          [0, 1, 0],
#          [0, 0, -2]]])

# Nonzero offsets on the 2D my_tensor. The two diagonal output dims grow to
# 3 + |offset| (4 for offset=±1, 5 for offset=±2); the remaining output dim
# still indexes the input rows. offset=k with (dim1=a, dim2=b) equals
# offset=-k with (dim1=b, dim2=a), so each group lists both spellings.

# offset=1 with (dim1=0, dim2=1).
embed_plus1_dims_0_1 = (
    torch.diag_embed(my_tensor, offset=1, dim1=0, dim2=1),
    torch.diag_embed(my_tensor, offset=1, dim1=0, dim2=-2),
    torch.diag_embed(my_tensor, offset=1, dim1=-3, dim2=1),
    torch.diag_embed(my_tensor, offset=1, dim1=-3, dim2=-2),
    torch.diag_embed(my_tensor, offset=-1, dim1=1, dim2=0),
    torch.diag_embed(my_tensor, offset=-1, dim1=1, dim2=-3),
    torch.diag_embed(my_tensor, offset=-1, dim1=-2, dim2=0),
    torch.diag_embed(my_tensor, offset=-1, dim1=-2, dim2=-3),
)
# tensor([[[0, 0, 0],
#          [7, -6, 9],
#          [0, 0, 0],
#          [0, 0, 0]],
#         [[0, 0, 0],
#          [0, 0, 0],
#          [-4, -3, 1],
#          [0, 0, 0]],
#         [[0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [5, 8, -2]],
#         [[0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0]]])

# offset=1 with (dim1=0, dim2=2).
embed_plus1_dims_0_2 = (
    torch.diag_embed(my_tensor, offset=1, dim1=0, dim2=2),
    torch.diag_embed(my_tensor, offset=1, dim1=0, dim2=-1),
    torch.diag_embed(my_tensor, offset=1, dim1=-3, dim2=2),
    torch.diag_embed(my_tensor, offset=1, dim1=-3, dim2=-1),
    torch.diag_embed(my_tensor, offset=-1, dim1=2, dim2=0),
    torch.diag_embed(my_tensor, offset=-1, dim1=2, dim2=-3),
    torch.diag_embed(my_tensor, offset=-1, dim1=-1, dim2=0),
    torch.diag_embed(my_tensor, offset=-1, dim1=-1, dim2=-3),
)
# tensor([[[0, 7, 0, 0],
#          [0, -6, 0, 0],
#          [0, 9, 0, 0]],
#         [[0, 0, -4, 0],
#          [0, 0, -3, 0],
#          [0, 0, 1, 0]],
#         [[0, 0, 0, 5],
#          [0, 0, 0, 8],
#          [0, 0, 0, -2]],
#         [[0, 0, 0, 0],
#          [0, 0, 0, 0],
#          [0, 0, 0, 0]]])

# offset=1 with (dim1=1, dim2=0).
embed_plus1_dims_1_0 = (
    torch.diag_embed(my_tensor, offset=1, dim1=1, dim2=0),
    torch.diag_embed(my_tensor, offset=1, dim1=1, dim2=-3),
    torch.diag_embed(my_tensor, offset=1, dim1=-2, dim2=0),
    torch.diag_embed(my_tensor, offset=1, dim1=-2, dim2=-3),
    torch.diag_embed(my_tensor, offset=-1, dim1=0, dim2=1),
    torch.diag_embed(my_tensor, offset=-1, dim1=0, dim2=-2),
    torch.diag_embed(my_tensor, offset=-1, dim1=-3, dim2=1),
    torch.diag_embed(my_tensor, offset=-1, dim1=-3, dim2=-2),
)
# tensor([[[0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0]],
#         [[7, -6, 9],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0]],
#         [[0, 0, 0],
#          [-4, -3, 1],
#          [0, 0, 0],
#          [0, 0, 0]],
#         [[0, 0, 0],
#          [0, 0, 0],
#          [5, 8, -2],
#          [0, 0, 0]]])

# offset=1 with (dim1=1, dim2=2): the default dims, so offset=1 alone is here.
embed_plus1_dims_1_2 = (
    torch.diag_embed(my_tensor, offset=1),
    torch.diag_embed(my_tensor, offset=1, dim1=1, dim2=2),
    torch.diag_embed(my_tensor, offset=1, dim1=1, dim2=-1),
    torch.diag_embed(my_tensor, offset=1, dim1=-2, dim2=2),
    torch.diag_embed(my_tensor, offset=1, dim1=-2, dim2=-1),
    torch.diag_embed(my_tensor, offset=-1, dim1=2, dim2=1),
    torch.diag_embed(my_tensor, offset=-1, dim1=2, dim2=-2),
    torch.diag_embed(my_tensor, offset=-1, dim1=-1, dim2=1),
    torch.diag_embed(my_tensor, offset=-1, dim1=-1, dim2=-2),
)
# tensor([[[0, 7, 0, 0],
#          [0, 0, -4, 0],
#          [0, 0, 0, 5],
#          [0, 0, 0, 0]],
#         [[0, -6, 0, 0],
#          [0, 0, -3, 0],
#          [0, 0, 0, 8],
#          [0, 0, 0, 0]],
#         [[0, 9, 0, 0],
#          [0, 0, 1, 0],
#          [0, 0, 0, -2],
#          [0, 0, 0, 0]]])

# offset=1 with (dim1=2, dim2=0).
embed_plus1_dims_2_0 = (
    torch.diag_embed(my_tensor, offset=1, dim1=2, dim2=0),
    torch.diag_embed(my_tensor, offset=1, dim1=2, dim2=-3),
    torch.diag_embed(my_tensor, offset=1, dim1=-1, dim2=0),
    torch.diag_embed(my_tensor, offset=1, dim1=-1, dim2=-3),
    torch.diag_embed(my_tensor, offset=-1, dim1=0, dim2=2),
    torch.diag_embed(my_tensor, offset=-1, dim1=0, dim2=-1),
    torch.diag_embed(my_tensor, offset=-1, dim1=-3, dim2=2),
    torch.diag_embed(my_tensor, offset=-1, dim1=-3, dim2=-1),
)
# tensor([[[0, 0, 0, 0],
#          [0, 0, 0, 0],
#          [0, 0, 0, 0]],
#         [[7, 0, 0, 0],
#          [-6, 0, 0, 0],
#          [9, 0, 0, 0]],
#         [[0, -4, 0, 0],
#          [0, -3, 0, 0],
#          [0, 1, 0, 0]],
#         [[0, 0, 5, 0],
#          [0, 0, 8, 0],
#          [0, 0, -2, 0]]])

# offset=1 with (dim1=2, dim2=1), i.e. offset=-1 with the default dims.
embed_plus1_dims_2_1 = (
    torch.diag_embed(my_tensor, offset=-1),
    torch.diag_embed(my_tensor, offset=1, dim1=2, dim2=1),
    torch.diag_embed(my_tensor, offset=1, dim1=2, dim2=-2),
    torch.diag_embed(my_tensor, offset=1, dim1=-1, dim2=1),
    torch.diag_embed(my_tensor, offset=1, dim1=-1, dim2=-2),
    torch.diag_embed(my_tensor, offset=-1, dim1=1, dim2=2),
    torch.diag_embed(my_tensor, offset=-1, dim1=1, dim2=-1),
    torch.diag_embed(my_tensor, offset=-1, dim1=-2, dim2=2),
    torch.diag_embed(my_tensor, offset=-1, dim1=-2, dim2=-1),
)
# tensor([[[0, 0, 0, 0],
#          [7, 0, 0, 0],
#          [0, -4, 0, 0],
#          [0, 0, 5, 0]],
#         [[0, 0, 0, 0],
#          [-6, 0, 0, 0],
#          [0, -3, 0, 0],
#          [0, 0, 8, 0]],
#         [[0, 0, 0, 0],
#          [9, 0, 0, 0],
#          [0, 1, 0, 0],
#          [0, 0, -2, 0]]])

# offset=2 with (dim1=0, dim2=1).
embed_plus2_dims_0_1 = (
    torch.diag_embed(my_tensor, offset=2, dim1=0, dim2=1),
    torch.diag_embed(my_tensor, offset=2, dim1=0, dim2=-2),
    torch.diag_embed(my_tensor, offset=2, dim1=-3, dim2=1),
    torch.diag_embed(my_tensor, offset=2, dim1=-3, dim2=-2),
    torch.diag_embed(my_tensor, offset=-2, dim1=1, dim2=0),
    torch.diag_embed(my_tensor, offset=-2, dim1=1, dim2=-3),
    torch.diag_embed(my_tensor, offset=-2, dim1=-2, dim2=0),
    torch.diag_embed(my_tensor, offset=-2, dim1=-2, dim2=-3),
)
# tensor([[[0, 0, 0],
#          [0, 0, 0],
#          [7, -6, 9],
#          [0, 0, 0],
#          [0, 0, 0]],
#         [[0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [-4, -3, 1],
#          [0, 0, 0]],
#         [[0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [5, 8, -2]],
#         [[0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0]],
#         [[0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0]]])

# offset=2 with (dim1=0, dim2=2).
embed_plus2_dims_0_2 = (
    torch.diag_embed(my_tensor, offset=2, dim1=0, dim2=2),
    torch.diag_embed(my_tensor, offset=2, dim1=0, dim2=-1),
    torch.diag_embed(my_tensor, offset=2, dim1=-3, dim2=2),
    torch.diag_embed(my_tensor, offset=2, dim1=-3, dim2=-1),
    torch.diag_embed(my_tensor, offset=-2, dim1=2, dim2=0),
    torch.diag_embed(my_tensor, offset=-2, dim1=2, dim2=-3),
    torch.diag_embed(my_tensor, offset=-2, dim1=-1, dim2=0),
    torch.diag_embed(my_tensor, offset=-2, dim1=-1, dim2=-3),
)
# tensor([[[0, 0, 7, 0, 0],
#          [0, 0, -6, 0, 0],
#          [0, 0, 9, 0, 0]],
#         [[0, 0, 0, -4, 0],
#          [0, 0, 0, -3, 0],
#          [0, 0, 0, 1, 0]],
#         [[0, 0, 0, 0, 5],
#          [0, 0, 0, 0, 8],
#          [0, 0, 0, 0, -2]],
#         [[0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0]],
#         [[0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0]]])

# offset=2 with (dim1=1, dim2=0).
embed_plus2_dims_1_0 = (
    torch.diag_embed(my_tensor, offset=2, dim1=1, dim2=0),
    torch.diag_embed(my_tensor, offset=2, dim1=1, dim2=-3),
    torch.diag_embed(my_tensor, offset=2, dim1=-2, dim2=0),
    torch.diag_embed(my_tensor, offset=2, dim1=-2, dim2=-3),
    torch.diag_embed(my_tensor, offset=-2, dim1=0, dim2=1),
    torch.diag_embed(my_tensor, offset=-2, dim1=0, dim2=-2),
    torch.diag_embed(my_tensor, offset=-2, dim1=-3, dim2=1),
    torch.diag_embed(my_tensor, offset=-2, dim1=-3, dim2=-2),
)
# tensor([[[0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0]],
#         [[0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0]],
#         [[7, -6, 9],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0]],
#         [[0, 0, 0],
#          [-4, -3, 1],
#          [0, 0, 0],
#          [0, 0, 0],
#          [0, 0, 0]],
#         [[0, 0, 0],
#          [0, 0, 0],
#          [5, 8, -2],
#          [0, 0, 0],
#          [0, 0, 0]]])

# offset=2 with (dim1=1, dim2=2): the default dims, so offset=2 alone is here.
embed_plus2_dims_1_2 = (
    torch.diag_embed(my_tensor, offset=2),
    torch.diag_embed(my_tensor, offset=2, dim1=1, dim2=2),
    torch.diag_embed(my_tensor, offset=2, dim1=1, dim2=-1),
    torch.diag_embed(my_tensor, offset=2, dim1=-2, dim2=2),
    torch.diag_embed(my_tensor, offset=2, dim1=-2, dim2=-1),
    torch.diag_embed(my_tensor, offset=-2, dim1=2, dim2=1),
    torch.diag_embed(my_tensor, offset=-2, dim1=2, dim2=-2),
    torch.diag_embed(my_tensor, offset=-2, dim1=-1, dim2=1),
    torch.diag_embed(my_tensor, offset=-2, dim1=-1, dim2=-2),
)
# tensor([[[0, 0, 7, 0, 0],
#          [0, 0, 0, -4, 0],
#          [0, 0, 0, 0, 5],
#          [0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0]],
#         [[0, 0, -6, 0, 0],
#          [0, 0, 0, -3, 0],
#          [0, 0, 0, 0, 8],
#          [0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0]],
#         [[0, 0, 9, 0, 0],
#          [0, 0, 0, 1, 0],
#          [0, 0, 0, 0, -2],
#          [0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0]]])

# offset=2 with (dim1=2, dim2=0).
embed_plus2_dims_2_0 = (
    torch.diag_embed(my_tensor, offset=2, dim1=2, dim2=0),
    torch.diag_embed(my_tensor, offset=2, dim1=2, dim2=-3),
    torch.diag_embed(my_tensor, offset=2, dim1=-1, dim2=0),
    torch.diag_embed(my_tensor, offset=2, dim1=-1, dim2=-3),
    torch.diag_embed(my_tensor, offset=-2, dim1=0, dim2=2),
    torch.diag_embed(my_tensor, offset=-2, dim1=0, dim2=-1),
    torch.diag_embed(my_tensor, offset=-2, dim1=-3, dim2=2),
    torch.diag_embed(my_tensor, offset=-2, dim1=-3, dim2=-1),
)
# tensor([[[0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0]],
#         [[0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0]],
#         [[7, 0, 0, 0, 0],
#          [-6, 0, 0, 0, 0],
#          [9, 0, 0, 0, 0]],
#         [[0, -4, 0, 0, 0],
#          [0, -3, 0, 0, 0],
#          [0, 1, 0, 0, 0]],
#         [[0, 0, 5, 0, 0],
#          [0, 0, 8, 0, 0],
#          [0, 0, -2, 0, 0]]])

# offset=2 with (dim1=2, dim2=1), i.e. offset=-2 with the default dims.
embed_plus2_dims_2_1 = (
    torch.diag_embed(my_tensor, offset=-2),
    torch.diag_embed(my_tensor, offset=2, dim1=2, dim2=1),
    torch.diag_embed(my_tensor, offset=2, dim1=2, dim2=-2),
    torch.diag_embed(my_tensor, offset=2, dim1=-1, dim2=1),
    torch.diag_embed(my_tensor, offset=2, dim1=-1, dim2=-2),
    torch.diag_embed(my_tensor, offset=-2, dim1=1, dim2=2),
    torch.diag_embed(my_tensor, offset=-2, dim1=1, dim2=-1),
    torch.diag_embed(my_tensor, offset=-2, dim1=-2, dim2=2),
    torch.diag_embed(my_tensor, offset=-2, dim1=-2, dim2=-1),
)
# tensor([[[0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0],
#          [7, 0, 0, 0, 0],
#          [0, -4, 0, 0, 0],
#          [0, 0, 5, 0, 0]],
#         [[0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0],
#          [-6, 0, 0, 0, 0],
#          [0, -3, 0, 0, 0],
#          [0, 0, 8, 0, 0]],
#         [[0, 0, 0, 0, 0],
#          [0, 0, 0, 0, 0],
#          [9, 0, 0, 0, 0],
#          [0, 1, 0, 0, 0],
#          [0, 0, -2, 0, 0]]])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)