DEV Community

theroyakash

What is the PyTorch `.detach()` method?

`detach()` is a method of PyTorch's `Tensor` class.

`tensor.detach()` returns a new tensor that shares storage with `tensor` but does not require gradients. `tensor.clone()` creates a copy of `tensor` with its own storage that keeps the original tensor's `requires_grad` setting.
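A minimal sketch to check both claims. The variable names `x`, `d`, and `c` are illustrative, and `data_ptr()` is used here only as a quick way to compare the underlying storage addresses:

```python
import torch

x = torch.ones(3, requires_grad=True)

d = x.detach()  # same storage as x, no gradient tracking
c = x.clone()   # separate storage, same requires_grad as x

print(d.requires_grad)               # False
print(c.requires_grad)               # True
print(d.data_ptr() == x.data_ptr())  # True: storage is shared
print(c.data_ptr() == x.data_ptr())  # False: clone owns its own copy
```

Because `d` shares storage with `x`, an in-place change to one is visible through the other.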

Use `detach()` when you want to remove a tensor from a computation graph, and `clone()` when you want a copy of the tensor that remains part of the computation graph it came from.
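The difference can be seen in how gradients flow. The following is a small sketch with illustrative names (`x`, `through_clone`, `through_detach`), not part of the original example:

```python
import torch

x = torch.ones(3, requires_grad=True)

# clone() stays in the graph, so backward() reaches x through the copy.
through_clone = (x.clone() * 2).sum()
through_clone.backward()
print(x.grad)  # tensor([2., 2., 2.])

# detach() leaves the graph: the result has no grad_fn and needs no gradient.
through_detach = (x.detach() * 2).sum()
print(through_detach.requires_grad)  # False
print(through_detach.grad_fn)        # None
```

Calling `backward()` on `through_detach` would raise an error, since nothing in it requires a gradient.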

Let's see that in an example:

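```python
import torch
import torchviz

# X requires gradients, so every operation on it is recorded in the graph.
X = torch.ones((28, 28), dtype=torch.float32, requires_grad=True)
y = X**2
z = X**2

result = (y + z).sum()

# Render the computation graph of `result` to Attached.png.
torchviz.make_dot(result).render('Attached', format='png')
```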

[Image: 1.png, the rendered computation graph, in which both branches lead back to X]

And now the same example with `detach()`:

```python
import torch
import torchviz

X = torch.ones((28, 28), dtype=torch.float32, requires_grad=True)
y = X**2
z = X.detach()**2  # this branch is cut off from the graph

result = (y + z).sum()

# Render to Detached.png so the first image is not overwritten.
torchviz.make_dot(result).render('Detached', format='png')
```

[Image: Screen Shot 2020-10-22 at 8.48.43 PM.png, the rendered computation graph with the detached branch missing]

As you can see, the branch that computes `X.detach()**2` is no longer tracked. This is reflected in the gradient of `result` with respect to `X`, which no longer includes the contribution of this branch: only `y = X**2` contributes, so `X.grad` is `2 * X` instead of `4 * X`.
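To check this numerically rather than visually, here is a short sketch that runs both variants and compares `X.grad`. The helper `grad_of_sum` and its `detach_second_branch` flag are hypothetical names introduced for this illustration:

```python
import torch

def grad_of_sum(detach_second_branch):
    """Return X.grad for result = (X**2 + z).sum(), with z optionally detached."""
    X = torch.ones((28, 28), dtype=torch.float32, requires_grad=True)
    y = X**2
    z = (X.detach() if detach_second_branch else X) ** 2
    (y + z).sum().backward()
    return X.grad

g_attached = grad_of_sum(False)  # both branches contribute: 2*X + 2*X
g_detached = grad_of_sum(True)   # only y contributes: 2*X

print(g_attached[0, 0].item())  # 4.0
print(g_detached[0, 0].item())  # 2.0
```

Since `X` is all ones, every entry of the gradient is 4 in the attached case and 2 in the detached case.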
