Super Kai (Kazuya Ito)

requires_grad=True with a tensor, backward() and retain_grad() in PyTorch

requires_grad(Optional-Default:False-Type:bool) set to True enables a tensor to compute and accumulate its gradient as shown below:

*Memos:

  • There are leaf tensors and non-leaf tensors.
  • data must be of float or complex type to use requires_grad=True.
  • backward() does backpropagation. *Backpropagation computes gradients of a loss (e.g. the mean of the differences between the model's predictions and the true values(train data)) with respect to each tensor, working from the output layer back to the input layer.
  • A gradient is accumulated(added to the existing one) each time backward() is called.
  • To call backward() without arguments:
    • requires_grad must be True.
    • the tensor must be a scalar(only one element) of float type (0D or more D, as long as it has only one element). *For a non-scalar tensor, pass a gradient argument to backward() as shown in the sketch after this list.
  • grad gets the accumulated gradient(None before the first backward() call).
  • is_leaf checks whether a tensor is a leaf tensor.
  • To call retain_grad(), requires_grad must be True.
  • To enable a non-leaf tensor to get a gradient through grad without a warning, retain_grad() must be called before backward().
  • Using retain_graph=True with backward() keeps the computation graph so a later backward() call through the same graph doesn't raise an error.
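
For example, here's a minimal sketch of the non-scalar case: backward() needs an explicit gradient argument of the same shape, otherwise it raises a RuntimeError (the exact message may vary by PyTorch version):

import torch

my_tensor = torch.tensor(data=[2., 3.], requires_grad=True)

# my_tensor.backward()
# RuntimeError: grad can be implicitly created only for scalar outputs

my_tensor.backward(gradient=torch.tensor(data=[1., 1.]))

my_tensor.grad
# tensor([1., 1.])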

1 tensor with backward():

import torch

my_tensor = torch.tensor(data=7., requires_grad=True) # Leaf tensor

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), None, True)

my_tensor.backward()

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

my_tensor.backward()

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), tensor(2.), True)

my_tensor.backward()

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), tensor(3.), True)
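Since a gradient is accumulated on each backward() call, you can reset it by assigning None to grad(zeroing it in place with grad.zero_() also works). A minimal sketch:

import torch

my_tensor = torch.tensor(data=7., requires_grad=True)

my_tensor.backward()
my_tensor.backward()

my_tensor.grad
# tensor(2.)

my_tensor.grad = None # Reset the accumulated gradient

my_tensor.backward()

my_tensor.grad
# tensor(1.)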

3 tensors with backward(retain_graph=True) and retain_grad():

import torch

tensor1 = torch.tensor(data=7., requires_grad=True) # Leaf tensor

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), None, True)

tensor1.backward()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

tensor2 = tensor1 * 4 # Non-leaf tensor

tensor2.retain_grad()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), None, False)

tensor2.backward(retain_graph=True) # Important

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(5.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(1.), False)

tensor3 = tensor2 * 5 # Non-leaf tensor

tensor3.retain_grad()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(5.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(1.), False)

tensor3, tensor3.grad, tensor3.is_leaf
# (tensor(140., grad_fn=<MulBackward0>), None, False)

tensor3.backward()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(25.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(6.), False)

tensor3, tensor3.grad, tensor3.is_leaf
# (tensor(140., grad_fn=<MulBackward0>), tensor(1.), False)
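Without retain_graph=True at tensor2.backward() above, tensor2's part of the graph would be freed, and the later tensor3.backward(), which backpropagates through that same part, would fail. A minimal sketch (the exact error message may vary by PyTorch version):

import torch

tensor1 = torch.tensor(data=7., requires_grad=True)

tensor2 = tensor1 * 4 # Non-leaf tensor

tensor2.backward() # Without retain_graph=True, the graph of tensor2 is freed

tensor3 = tensor2 * 5 # Non-leaf tensor

tensor3.backward()
# RuntimeError: Trying to backward through the graph a second time ...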

In addition, 3 tensors with detach_() and requires_grad_(requires_grad=True), which cuts tensor3 out of the computation graph so gradients no longer flow back to tensor1 and tensor2 (and retain_graph=True is no longer needed):

import torch

tensor1 = torch.tensor(data=7., requires_grad=True) # Leaf tensor

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), None, True)

tensor1.backward()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

tensor2 = tensor1 * 4 # Non-leaf tensor

tensor2.retain_grad()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), None, False)

tensor2.backward()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(5.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(1.), False)

tensor3 = tensor2 * 5 # Non-leaf tensor
tensor3 = tensor3.detach_().requires_grad_(requires_grad=True) # Important
# detach_() cuts tensor3 from the graph in place, making it a leaf tensor again.
tensor3.retain_grad()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(5.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(1.), False)

tensor3, tensor3.grad, tensor3.is_leaf
# (tensor(140., requires_grad=True), None, True)

tensor3.backward()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(5.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(1.), False)

tensor3, tensor3.grad, tensor3.is_leaf
# (tensor(140., requires_grad=True), tensor(1.), True)
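detach_() cuts a tensor out of the graph in place; detach() instead returns a new leaf tensor and leaves the original non-leaf tensor untouched. A minimal sketch of the out-of-place variant:

import torch

tensor1 = torch.tensor(data=7., requires_grad=True)

tensor2 = tensor1 * 4 # Non-leaf tensor

tensor3 = tensor2.detach().requires_grad_(requires_grad=True) # New leaf tensor;
                                                              # tensor2 keeps its graph
tensor3.backward()

tensor1.grad, tensor3.grad # The graph is cut at tensor3, so tensor1 gets no gradient
# (None, tensor(1.))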

In addition, you can manually set a gradient on a tensor whether requires_grad is True or False, as shown below:
*Memos:

  • A gradient must be:
    • a tensor.
    • of the same type and size as its tensor (see the sketch after this list).
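
A minimal sketch of the mismatched cases (the exact error messages may vary by PyTorch version):

import torch

my_tensor = torch.tensor(data=7.)

try:
    my_tensor.grad = torch.tensor(data=[4., 5.]) # Wrong size
except RuntimeError as e:
    print(e) # assigned grad has data of a different size

try:
    my_tensor.grad = torch.tensor(data=4) # Wrong type(int64)
except RuntimeError as e:
    print(e) # assigned grad has data of a different type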

float:

import torch

my_tensor = torch.tensor(data=7., requires_grad=True)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), None, True)

my_tensor.grad = torch.tensor(data=4.)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), tensor(4.), True)

my_tensor = torch.tensor(data=7., requires_grad=False)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7.), None, True)

my_tensor.grad = torch.tensor(data=4.)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7.), tensor(4.), True)
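A manually set gradient is accumulated by a later backward() call just like a computed one. A minimal sketch:

import torch

my_tensor = torch.tensor(data=7., requires_grad=True)

my_tensor.grad = torch.tensor(data=4.)

my_tensor.backward() # Adds the computed gradient 1. to the existing 4.

my_tensor.grad
# tensor(5.)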

complex:

import torch

my_tensor = torch.tensor(data=7.+0.j, requires_grad=True)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7.+0.j, requires_grad=True), None, True)

my_tensor.grad = torch.tensor(data=4.+0.j)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7.+0.j, requires_grad=True), tensor(4.+0.j), True)

my_tensor = torch.tensor(data=7.+0.j, requires_grad=False)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7.+0.j), None, True)

my_tensor.grad = torch.tensor(data=4.+0.j)

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7.+0.j), tensor(4.+0.j), True)
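Note that while a complex tensor can have requires_grad=True, backward() without a gradient argument only works for a real scalar, so it fails on a complex tensor. A minimal sketch (the exact message may vary by PyTorch version):

import torch

my_tensor = torch.tensor(data=7.+0.j, requires_grad=True)

# my_tensor.backward()
# RuntimeError: grad can be implicitly created only for real scalar outputs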
