DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

Type promotion, result_type(), promote_types() and can_cast() in PyTorch

Buy Me a Coffee

*My post explains how to create and access a tensor.

Arithmetic operations (addition, subtraction, multiplication, etc.) apply type promotion in PyTorch, as shown below:

*Memos:

  • The category priority of type promotion is complex > float > int > bool, i.e. the result takes the higher category.
  • If both types are in the same category but differ in size, the larger one is taken (e.g. int32 and int64 give int64).
import torch

# Each case below adds two tensors of different dtypes with torch.add();
# the comments under each call show the printed result and its dtype.

# bool + int64 -> int64 (int outranks bool)
tensor1 = torch.tensor([True, False, True], dtype=torch.bool)
tensor2 = torch.tensor([3, 4, 5], dtype=torch.int64)

torch.add(input=tensor1, other=tensor2)
# tensor([4, 4, 6])

torch.add(input=tensor1, other=tensor2).dtype
# torch.int64

# int64 + float32 -> float32 (float outranks int)
tensor1 = torch.tensor([0, 1, 2], dtype=torch.int64)
tensor2 = torch.tensor([3., 4., 5.], dtype=torch.float32)

torch.add(input=tensor1, other=tensor2)
# tensor([3., 5., 7.])

torch.add(input=tensor1, other=tensor2).dtype
# torch.float32

# float32 + complex64 -> complex64 (complex outranks float)
tensor1 = torch.tensor([0., 1., 2.], dtype=torch.float32)
tensor2 = torch.tensor([3.+0.j, 4.+0.j, 5.+0.j], 
                       dtype=torch.complex64)
torch.add(input=tensor1, other=tensor2)
# tensor([3.+0.j, 5.+0.j, 7.+0.j])

torch.add(input=tensor1, other=tensor2).dtype
# torch.complex64

# int32 + int64 -> int64 (same category: the larger size wins)
tensor1 = torch.tensor([0, 1, 2], dtype=torch.int32)
tensor2 = torch.tensor([3, 4, 5], dtype=torch.int64)

torch.add(input=tensor1, other=tensor2)
# tensor([3, 5, 7])

torch.add(input=tensor1, other=tensor2).dtype
# torch.int64

# float32 + float64 -> float64 (same category: the larger size wins)
tensor1 = torch.tensor([0., 1., 2.], dtype=torch.float32)
tensor2 = torch.tensor([3., 4., 5.], dtype=torch.float64)

torch.add(input=tensor1, other=tensor2)
# tensor([3., 5., 7.], dtype=torch.float64)

torch.add(input=tensor1, other=tensor2).dtype
# torch.float64

# complex32 + complex64 -> complex64 (same category: the larger size wins)
tensor1 = torch.tensor([0.+0.j, 1.+0.j, 2.+0.j], 
                       dtype=torch.complex32)
tensor2 = torch.tensor([3.+0.j, 4.+0.j, 5.+0.j], 
                       dtype=torch.complex64)
torch.add(input=tensor1, other=tensor2)
# tensor([3.+0.j, 5.+0.j, 7.+0.j])

torch.add(input=tensor1, other=tensor2).dtype
# torch.complex64
Enter fullscreen mode Exit fullscreen mode

result_type() returns the dtype produced by promoting two operands: two tensors (0D or more D, with zero or more elements), a tensor and a scalar, or two scalars. It is shown below:

*Memos:

  • result_type() can be used with torch but not with a tensor.
  • The 1st argument with torch is tensor(Required-Type:tensor or scalar of int, float, complex or bool). *A scalar must be passed positionally, without tensor=.
  • The 2nd argument with torch is other(Required-Type:tensor or scalar of int, float, complex or bool). *If the 1st argument is a scalar, pass the 2nd argument positionally as well.
import torch

# Each case asks torch.result_type() which dtype an operation on the two
# operands would produce; the comment after each group shows the answer.

# bool with int64 -> int64, whether the operands are tensors or scalars.
tensor1 = torch.tensor([True, False, True], dtype=torch.bool)
tensor2 = torch.tensor([3, 4, 5], dtype=torch.int64)

torch.result_type(tensor=tensor1, other=tensor2)
# A scalar 1st argument must be passed positionally (without tensor=).
torch.result_type(True, tensor2)
torch.result_type(tensor=tensor1, other=3)
torch.result_type(True, 3)
# torch.int64

# int64 with float32 -> float32 (float outranks int)
tensor1 = torch.tensor([0, 1, 2], dtype=torch.int64)
tensor2 = torch.tensor([3., 4., 5.], dtype=torch.float32)

torch.result_type(tensor=tensor1, other=tensor2)
# torch.float32

# float32 with complex64 -> complex64 (complex outranks float)
tensor1 = torch.tensor([0., 1., 2.], dtype=torch.float32)
tensor2 = torch.tensor([3.+0.j, 4.+0.j, 5.+0.j],
                       dtype=torch.complex64)
torch.result_type(tensor=tensor1, other=tensor2)
# torch.complex64

# int32 with int64 -> int64 (same category: the larger size wins)
tensor1 = torch.tensor([0, 1, 2], dtype=torch.int32)
tensor2 = torch.tensor([3, 4, 5], dtype=torch.int64)

torch.result_type(tensor=tensor1, other=tensor2)
# torch.int64

# float32 with float64 -> float64 (same category: the larger size wins)
tensor1 = torch.tensor([0., 1., 2.], dtype=torch.float32)
tensor2 = torch.tensor([3., 4., 5.], dtype=torch.float64)

torch.result_type(tensor=tensor1, other=tensor2)
# torch.float64

# complex32 with complex64 -> complex64 (same category: the larger size wins)
tensor1 = torch.tensor([0.+0.j, 1.+0.j, 2.+0.j],
                       dtype=torch.complex32)
tensor2 = torch.tensor([3.+0.j, 4.+0.j, 5.+0.j],
                       dtype=torch.complex64)
torch.result_type(tensor=tensor1, other=tensor2)
# torch.complex64
Enter fullscreen mode Exit fullscreen mode

promote_types() returns the dtype that two given dtypes are promoted to, as shown below:

*Memos:

  • promote_types() can be used with torch but not with a tensor.
  • The 1st argument with torch is type1(Required-Type:dtype).
  • The 2nd argument with torch is type2(Required-Type:dtype).
  • type1 and type2 can also accept the Python built-in types int, float and bool, but not complex.
import torch

# Each group of torch.promote_types() calls returns the dtype shown in the
# comment that follows it. Python built-in types (bool, int, float) may be
# passed in place of torch dtypes.

# bool with int64, and int32 with int64 -> int64
torch.promote_types(type1=torch.bool, type2=torch.int64)
torch.promote_types(type1=torch.bool, type2=int)
torch.promote_types(type1=bool, type2=torch.int64)
torch.promote_types(type1=bool, type2=int)
torch.promote_types(type1=torch.int32, type2=torch.int64)
# torch.int64

# int64 with float32 -> float32 (float outranks int)
torch.promote_types(type1=torch.int64, type2=torch.float32)
# torch.float32

# float32 or complex32 with complex64 -> complex64
torch.promote_types(type1=torch.float32, type2=torch.complex64)
torch.promote_types(type1=torch.complex32, type2=torch.complex64)
# torch.complex64

# float32 with float64, and built-in int with built-in float -> float64
torch.promote_types(type1=torch.float32, type2=torch.float64)
torch.promote_types(type1=int, type2=float)
# torch.float64
Enter fullscreen mode Exit fullscreen mode

can_cast() checks whether dtype from can be cast to dtype to under PyTorch's casting rules (the category of from must not be higher than the category of to), returning a boolean value as shown below:

*Memos:

  • can_cast() can be used with torch but not with a tensor.
  • The 1st argument with torch is from(Required-Type:dtype). *It must be passed positionally, without from= (from is a reserved keyword in Python).
  • The 2nd argument with torch is to(Required-Type:dtype).
  • from and to can also accept the Python built-in types int, float and bool, but not complex.
import torch

# The first argument (from) is passed positionally; `to` is given by keyword.

# Every call in this group returns True. Casting to the same or a higher
# category is allowed, even to a smaller size within a category
# (e.g. int64 -> int32, float64 -> float32, complex64 -> complex32).
torch.can_cast(torch.bool, to=torch.int64)
torch.can_cast(bool, to=torch.int64)
torch.can_cast(torch.bool, to=int)
torch.can_cast(bool, to=int)
torch.can_cast(torch.int64, to=torch.float32)
torch.can_cast(torch.float32, to=torch.complex64)
torch.can_cast(torch.int32, to=torch.int64)
torch.can_cast(torch.int64, to=torch.int32)
torch.can_cast(torch.float32, to=torch.float64)
torch.can_cast(torch.float64, to=torch.float32)
torch.can_cast(torch.complex32, to=torch.complex64)
torch.can_cast(torch.complex64, to=torch.complex32)
# True

# Every call in this group returns False: each casts to a lower category
# (int -> bool, float -> int, complex -> float).
torch.can_cast(torch.int64, to=torch.bool)
torch.can_cast(int, to=torch.bool)
torch.can_cast(torch.int64, to=bool)
torch.can_cast(int, to=bool)
torch.can_cast(torch.float32, to=torch.int64)
torch.can_cast(torch.complex64, to=torch.float32)
# False
Enter fullscreen mode Exit fullscreen mode

Top comments (0)