DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

Type promotion, result_type(), promote_types() and can_cast() in PyTorch

Buy Me a Coffee

*My post explains how to create and access a tensor.

Arithmetic operations (addition, subtraction, multiplication, etc.) apply type promotion in PyTorch as shown below:

*Memos:

  • The priority of type promotion is complex <- float <- int <- bool.
  • If the types are in the same category but differ in size, the larger one is taken. (e.g. int32 and int64 give int64)
import torch

# bool + int64 -> int64: int outranks bool, and True/False are added as 1/0.
tensor1 = torch.tensor([True, False, True], dtype=torch.bool)
tensor2 = torch.tensor([3, 4, 5], dtype=torch.int64)

torch.add(input=tensor1, other=tensor2)
# tensor([4, 4, 6])

torch.add(input=tensor1, other=tensor2).dtype
# torch.int64

# int64 + float32 -> float32: the category (float over int) decides,
# even though int64 is the wider type.
tensor1 = torch.tensor([0, 1, 2], dtype=torch.int64)
tensor2 = torch.tensor([3., 4., 5.], dtype=torch.float32)

torch.add(input=tensor1, other=tensor2)
# tensor([3., 5., 7.])

torch.add(input=tensor1, other=tensor2).dtype
# torch.float32

# float32 + complex64 -> complex64: complex outranks float.
tensor1 = torch.tensor([0., 1., 2.], dtype=torch.float32)
tensor2 = torch.tensor([3.+0.j, 4.+0.j, 5.+0.j], 
                       dtype=torch.complex64)
torch.add(input=tensor1, other=tensor2)
# tensor([3.+0.j, 5.+0.j, 7.+0.j])

torch.add(input=tensor1, other=tensor2).dtype
# torch.complex64

# Same category (int), different size: int32 + int64 -> int64.
tensor1 = torch.tensor([0, 1, 2], dtype=torch.int32)
tensor2 = torch.tensor([3, 4, 5], dtype=torch.int64)

torch.add(input=tensor1, other=tensor2)
# tensor([3, 5, 7])

torch.add(input=tensor1, other=tensor2).dtype
# torch.int64

# Same category (float), different size: float32 + float64 -> float64.
tensor1 = torch.tensor([0., 1., 2.], dtype=torch.float32)
tensor2 = torch.tensor([3., 4., 5.], dtype=torch.float64)

torch.add(input=tensor1, other=tensor2)
# tensor([3., 5., 7.], dtype=torch.float64)

torch.add(input=tensor1, other=tensor2).dtype
# torch.float64

# Same category (complex), different size: complex32 + complex64 -> complex64.
tensor1 = torch.tensor([0.+0.j, 1.+0.j, 2.+0.j], 
                       dtype=torch.complex32)
tensor2 = torch.tensor([3.+0.j, 4.+0.j, 5.+0.j], 
                       dtype=torch.complex64)
torch.add(input=tensor1, other=tensor2)
# tensor([3.+0.j, 5.+0.j, 7.+0.j])

torch.add(input=tensor1, other=tensor2).dtype
# torch.complex64
Enter fullscreen mode Exit fullscreen mode

result_type() can check the result type of an arithmetic operation on two 0D or more D tensors of zero or more elements, on two scalars, or on a 0D or more D tensor of zero or more elements and a scalar, getting a dtype as shown below:

*Memos:

  • result_type() can be used with torch but not with a tensor.
  • The 1st argument with torch is tensor(Required-Type:tensor or scalar of int, float, complex or bool). *A scalar must be passed positionally, without tensor=.
  • The 2nd argument with torch is other(Required-Type:tensor or scalar of int, float, complex or bool).
import torch

# bool combined with int64 -> int64, whether each operand is a tensor or a
# Python scalar.
tensor1 = torch.tensor([True, False, True], dtype=torch.bool)
tensor2 = torch.tensor([3, 4, 5], dtype=torch.int64)

torch.result_type(tensor=tensor1, other=tensor2)
# A scalar in the 1st position must be passed positionally: `tensor=` only
# accepts a tensor, so `tensor=True` is rejected.
torch.result_type(True, tensor2)
torch.result_type(tensor=tensor1, other=3)
torch.result_type(True, 3)
# torch.int64 (all four calls above)

# int64 with float32 -> float32: float outranks int.
tensor1 = torch.tensor([0, 1, 2], dtype=torch.int64)
tensor2 = torch.tensor([3., 4., 5.], dtype=torch.float32)

torch.result_type(tensor=tensor1, other=tensor2)
# torch.float32

# float32 with complex64 -> complex64: complex outranks float.
tensor1 = torch.tensor([0., 1., 2.], dtype=torch.float32)
tensor2 = torch.tensor([3.+0.j, 4.+0.j, 5.+0.j],
                       dtype=torch.complex64)
torch.result_type(tensor=tensor1, other=tensor2)
# torch.complex64

# Same category (int), different size: the larger one is taken.
tensor1 = torch.tensor([0, 1, 2], dtype=torch.int32)
tensor2 = torch.tensor([3, 4, 5], dtype=torch.int64)

torch.result_type(tensor=tensor1, other=tensor2)
# torch.int64

# Same category (float), different size: the larger one is taken.
tensor1 = torch.tensor([0., 1., 2.], dtype=torch.float32)
tensor2 = torch.tensor([3., 4., 5.], dtype=torch.float64)

torch.result_type(tensor=tensor1, other=tensor2)
# torch.float64

# Same category (complex), different size: the larger one is taken.
tensor1 = torch.tensor([0.+0.j, 1.+0.j, 2.+0.j],
                       dtype=torch.complex32)
tensor2 = torch.tensor([3.+0.j, 4.+0.j, 5.+0.j],
                       dtype=torch.complex64)
torch.result_type(tensor=tensor1, other=tensor2)
# torch.complex64
Enter fullscreen mode Exit fullscreen mode

promote_types() can check the result type of two dtypes, getting a dtype as shown below:

*Memos:

  • promote_types() can be used with torch but not with a tensor.
  • The 1st argument with torch is type1(Required-Type:dtype).
  • The 2nd argument with torch is type2(Required-Type:dtype).
  • type1 and type2 can also accept the Python built-in types int, float and bool, but not complex.
import torch

# bool with int64 -> int64; Python's built-in bool/int are accepted in place
# of a torch dtype. The last call shows same-category promotion by size.
torch.promote_types(type1=torch.bool, type2=torch.int64)
torch.promote_types(type1=torch.bool, type2=int)
torch.promote_types(type1=bool, type2=torch.int64)
torch.promote_types(type1=bool, type2=int)
torch.promote_types(type1=torch.int32, type2=torch.int64)
# torch.int64 (all five calls above)

# float outranks int, even though int64 is the wider type.
torch.promote_types(type1=torch.int64, type2=torch.float32)
# torch.float32

# complex outranks float; within complex, the larger size is taken.
torch.promote_types(type1=torch.float32, type2=torch.complex64)
torch.promote_types(type1=torch.complex32, type2=torch.complex64)
# torch.complex64 (both calls above)

# Within float, the larger size is taken; built-in int with built-in float
# also gives float64.
torch.promote_types(type1=torch.float32, type2=torch.float64)
torch.promote_types(type1=int, type2=float)
# torch.float64 (both calls above)
Enter fullscreen mode Exit fullscreen mode

can_cast() can check whether the category of a dtype (from) is promotable to the category of another dtype (to), getting a boolean value as shown below:

*Memos:

  • can_cast() can be used with torch but not with a tensor.
  • The 1st argument with torch is from(Required-Type:dtype). *It must be passed positionally, without from=, because from is a reserved keyword in Python.
  • The 2nd argument with torch is to(Required-Type:dtype).
  • from and to can also accept the Python built-in types int, float and bool, but not complex.
import torch

# Casting to the same or a higher category (bool -> int -> float -> complex).
# Only the category is compared, not the size, so narrowing within one
# category (e.g. int64 -> int32, float64 -> float32) is also True.
torch.can_cast(torch.bool, to=torch.int64)
torch.can_cast(bool, to=torch.int64)
torch.can_cast(torch.bool, to=int)
torch.can_cast(bool, to=int)
torch.can_cast(torch.int64, to=torch.float32)
torch.can_cast(torch.float32, to=torch.complex64)
torch.can_cast(torch.int32, to=torch.int64)
torch.can_cast(torch.int64, to=torch.int32)
torch.can_cast(torch.float32, to=torch.float64)
torch.can_cast(torch.float64, to=torch.float32)
torch.can_cast(torch.complex32, to=torch.complex64)
torch.can_cast(torch.complex64, to=torch.complex32)
# True (every call above)

# Casting to a lower category (int -> bool, float -> int, complex -> float).
torch.can_cast(torch.int64, to=torch.bool)
torch.can_cast(int, to=torch.bool)
torch.can_cast(torch.int64, to=bool)
torch.can_cast(int, to=bool)
torch.can_cast(torch.float32, to=torch.int64)
torch.can_cast(torch.complex64, to=torch.float32)
# False (every call above)
Enter fullscreen mode Exit fullscreen mode

Top comments (0)

Image of Docusign

🛠️ Bring your solution into Docusign. Reach over 1.6M customers.

Docusign is now extensible. Overcome challenges with disconnected products and inaccessible data by bringing your solutions into Docusign and publishing to 1.6M customers in the App Center.

Learn more