DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

Set and get `dtype` in PyTorch

Buy Me a Coffee

*Memos:

You can set and get `dtype` as shown below:

*Memos:

tensor(). *My post explains tensor():

import torch

# Input data reused by the conversions below, grouped by element type.
int_data = [0, 1, 2]
float_data = [0., 1., 2.]
complex_data = [0.+0.j, 0.+7.j, 2.+0.j]
bool_data = [True, False, True]

# int -> float64. The builtin `float` is accepted in place of torch.float64,
# so the second assignment produces the same result as the first.
my_tensor = torch.tensor(data=int_data, dtype=torch.float64)
my_tensor = torch.tensor(data=int_data, dtype=float)

my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([0., 1., 2.],
#  dtype=torch.float64), torch.float64,
#  'torch.DoubleTensor')

# int -> complex64.
my_tensor = torch.tensor(data=int_data, dtype=torch.complex64)

my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([0.+0.j, 1.+0.j, 2.+0.j]),
#  torch.complex64,
#  'torch.ComplexFloatTensor')

# int -> bool (zero becomes False, non-zero becomes True). The builtin
# `bool` is accepted in place of torch.bool.
my_tensor = torch.tensor(data=int_data, dtype=torch.bool)
my_tensor = torch.tensor(data=int_data, dtype=bool)

my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([False, True, True]), torch.bool, 'torch.BoolTensor')

# float -> int64. The builtin `int` is accepted in place of torch.int64.
my_tensor = torch.tensor(data=float_data, dtype=torch.int64)
my_tensor = torch.tensor(data=float_data, dtype=int)

my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([0, 1, 2]), torch.int64, 'torch.LongTensor')

# float -> complex64.
my_tensor = torch.tensor(data=float_data, dtype=torch.complex64)

my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([0.+0.j, 1.+0.j, 2.+0.j]),
#  torch.complex64,
#  'torch.ComplexFloatTensor')

# float -> bool.
my_tensor = torch.tensor(data=float_data, dtype=torch.bool)
my_tensor = torch.tensor(data=float_data, dtype=bool)

my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([False, True, True]), torch.bool, 'torch.BoolTensor')

# complex -> bool. A purely imaginary value (0.+7.j) still maps to True.
my_tensor = torch.tensor(data=complex_data, dtype=torch.bool)
my_tensor = torch.tensor(data=complex_data, dtype=bool)

my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([False, True, True]), torch.bool, 'torch.BoolTensor')

# bool -> int64 (True becomes 1, False becomes 0).
my_tensor = torch.tensor(data=bool_data, dtype=torch.int64)
my_tensor = torch.tensor(data=bool_data, dtype=int)

my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([1, 0, 1]), torch.int64, 'torch.LongTensor')

# bool -> float64.
my_tensor = torch.tensor(data=bool_data, dtype=torch.float64)
my_tensor = torch.tensor(data=bool_data, dtype=float)

my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([1., 0., 1.], dtype=torch.float64),
#  torch.float64,
#  'torch.DoubleTensor')

# bool -> complex64.
my_tensor = torch.tensor(data=bool_data, dtype=torch.complex64)

my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([1.+0.j, 0.+0.j, 1.+0.j]),
#  torch.complex64,
#  'torch.ComplexFloatTensor')
Enter fullscreen mode Exit fullscreen mode

arange(). *My post explains arange():

import torch

# Set the dtype at creation time: the integer range 5, 8, 11, 14 is
# produced directly as float64 instead of the integer dtype the
# arguments would otherwise imply.
my_tensor = torch.arange(
    start=5,
    end=15,
    step=3,
    dtype=torch.float64,
)

# Get the dtype back, both as a torch.dtype and as a legacy type string.
my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([5., 8., 11., 14.], dtype=torch.float64),
#  torch.float64,
#  'torch.DoubleTensor')
Enter fullscreen mode Exit fullscreen mode

rand(). *My post explains rand():

import torch

# Draw three random values as float64. The size is passed positionally
# here; torch.rand accepts the shape either way.
my_tensor = torch.rand(3, dtype=torch.float64)

# The values differ on every run; only the dtype and type string are fixed.
my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([0.4620, 0.6369, 0.5189], dtype=torch.float64),
#  torch.float64,
#  'torch.DoubleTensor')
Enter fullscreen mode Exit fullscreen mode

rand_like(). *My post explains rand_like():

import torch

# The template tensor supplies only the shape; its own dtype is
# overridden by the explicit dtype argument below.
template = torch.tensor([7., 4., 5.])
my_tensor = torch.rand_like(input=template, dtype=torch.float64)

# The values differ on every run; only the dtype and type string are fixed.
my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor([0.7677, 0.2914, 0.3266], dtype=torch.float64),
#  torch.float64,
#  'torch.DoubleTensor')
Enter fullscreen mode Exit fullscreen mode

sum(). *My post explains sum():

import torch

# Sum the elements and request the result as float64 via the dtype argument.
values = torch.tensor([0., 1., 2., 3.])
my_tensor = torch.sum(input=values, dtype=torch.float64)

# The result is a 0-dimensional tensor carrying the requested dtype.
my_tensor, my_tensor.dtype, my_tensor.type()
# (tensor(6., dtype=torch.float64), torch.float64, 'torch.DoubleTensor')
Enter fullscreen mode Exit fullscreen mode

view(). *My post explains view():

import torch

# Reshape a 3-element float tensor into a (3, 1) column.
my_tensor1 = torch.tensor([0., 1., 2.]).view(size=(3, 1))

# Reinterpret the same underlying bytes as bool (no value conversion).
# The view must be taken from my_tensor1, the tensor defined in this
# snippet. Each float element is split into 1-byte bools, so the last
# dimension grows from 1 to 4, as shown in the output below.
my_tensor2 = my_tensor1.view(dtype=torch.bool)

# Inspect the dtype of the reinterpreted tensor (my_tensor2), which is the
# one the output below reports as torch.bool.
my_tensor1, my_tensor2, my_tensor2.dtype, my_tensor2.type()
# (tensor([[0.],
#          [1.],
#          [2.]]),
#  tensor([[False, False, False, False],
#          [False, False,  True,  True],
#          [False, False, False,  True]]),
#  torch.bool,
#  'torch.BoolTensor')
Enter fullscreen mode Exit fullscreen mode

Top comments (0)