Difference between "detach()" and "with torch.no_grad()" in PyTorch?

Python, PyTorch, Autograd

Python Problem Overview


I know of two ways to exclude elements of a computation from the gradient calculation in backward()

Method 1: using with torch.no_grad()

with torch.no_grad():
    y = reward + gamma * torch.max(net.forward(x))
loss = criterion(net.forward(torch.from_numpy(o)), y)
loss.backward()

Method 2: using .detach()

y = reward + gamma * torch.max(net.forward(x))
loss = criterion(net.forward(torch.from_numpy(o)), y.detach())
loss.backward()

Is there a difference between these two? Are there benefits/downsides to either?

Python Solutions


Solution 1 - Python

tensor.detach() creates a new tensor that shares storage with tensor but does not require grad. It is detached from the computational graph, so no gradient will be backpropagated along this variable.

The context manager with torch.no_grad() temporarily sets every requires_grad flag to False: inside the block, no operation builds the graph.

The difference is that detach() applies only to the single tensor it is called on, while torch.no_grad affects every operation that takes place inside the with block. In addition, torch.no_grad uses less memory, because it knows from the start that no gradients will be needed, so it does not keep intermediate results for the backward pass.
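To illustrate both points, here is a minimal sketch (the tiny Linear model net and the input x are made-up names for this example): under no_grad no graph is built at all, whereas detach() only cuts the returned tensor off a graph that was still constructed.

import torch

net = torch.nn.Linear(4, 1)        # hypothetical tiny model, for illustration only
x = torch.randn(1, 4)

# no_grad: no graph is built, so no intermediates are stored for backward
with torch.no_grad():
    y_ng = net(x)
print(y_ng.requires_grad, y_ng.grad_fn)    # False None

# detach: the forward graph is still built (net's parameters require grad);
# only the returned tensor is cut off from it
y_det = net(x).detach()
print(y_det.requires_grad, y_det.grad_fn)  # False None
print(net(x).grad_fn)                      # e.g. <AddmmBackward0 ...> -- a graph exists here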


Solution 2 - Python

detach()

One example without detach():

import torch
from torchviz import make_dot

x = torch.ones(2, requires_grad=True)
y = 2 * x
z = 3 + x
r = (y + z).sum()
make_dot(r)    # visualize the autograd graph of r

[torchviz graph of r: the root r is shown in green, and both branches lead back to the leaf tensor x, shown in blue]

The end result r (in green) is the root of the AD computational graph, and the leaf tensor x is shown in blue.

Another example with detach():

import torch
from torchviz import make_dot

x = torch.ones(2, requires_grad=True)
y = 2 * x
z = 3 + x.detach()   # z no longer tracks back to x
r = (y + z).sum()
make_dot(r)

[torchviz graph of r: only the y branch leads back to the leaf x; the detached z branch contributes no edge to the graph]

This is the same as:

import torch
from torchviz import make_dot

x = torch.ones(2, requires_grad=True)
y = 2 * x
z = 3 + x.data       # .data also bypasses autograd
r = (y + z).sum()
make_dot(r)

But x.data is the old notation; x.detach() is the new, preferred way.
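The two produce the same graph above, but they are not fully interchangeable: in-place edits made through a detach()-ed tensor are tracked by autograd's version counter, while .data bypasses that check. A minimal sketch of the difference (an illustration, not part of the original answer):

import torch

a = torch.tensor([1., 2., 3.], requires_grad=True)

b = a.exp()          # exp() saves its output for the backward pass
b.detach().zero_()   # in-place edit through detach() bumps the version counter
# b.sum().backward() # would raise: "... modified by an inplace operation"

c = a.exp()
c.data.zero_()       # .data bypasses the version counter
c.sum().backward()   # runs, but the gradient is silently wrong (all zeros here)
print(a.grad)        # tensor([0., 0., 0.])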

How does x.detach() differ from x?

print(x)
print(x.detach())

Out:

tensor([1., 1.], requires_grad=True)
tensor([1., 1.])

So x.detach() is a way to remove requires_grad: what you get is a new detached tensor (detached from the AD computational graph).
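Note that the detached tensor still shares storage with x (as Solution 1 mentions), so in-place changes made through it are visible on x. A quick sketch:

import torch

x = torch.ones(2, requires_grad=True)
d = x.detach()
d += 1                    # in-place edit through the detached view
print(x)                  # tensor([2., 2.], requires_grad=True) -- same storage
print(d.requires_grad)    # False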

torch.no_grad

torch.no_grad is actually a class.

import torch

x = torch.ones(2, requires_grad=True)
with torch.no_grad():
    y = x * 2
print(y.requires_grad)

Out:

False

From help(torch.no_grad):

>Disabling gradient calculation is useful for inference, when you are sure that you will not call Tensor.backward(). It will reduce memory consumption for computations that would otherwise have requires_grad=True.
>
>In this mode, the result of every computation will have requires_grad=False, even when the inputs have requires_grad=True.
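Because it is a class, torch.no_grad() can also be used as a decorator so that an entire function runs with gradient tracking disabled. A minimal sketch (the model and predict names are made up for illustration):

import torch

model = torch.nn.Linear(3, 1)  # hypothetical tiny model

@torch.no_grad()               # the whole function runs with gradients disabled
def predict(inp):
    return model(inp)

out = predict(torch.randn(2, 3))
print(out.requires_grad)       # False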

Solution 3 - Python

A simple explanation is that with torch.no_grad() behaves like a block in which everything computed inside it temporarily has its requires_grad flag set to False. So there is no need to specify anything further if you want to stop backpropagation through the variables or functions computed inside that block.

tensor.detach(), on the other hand, simply detaches a variable from the gradient computation graph, as the name suggests. It is used when this only has to be done for a limited number of variables or functions, for example when displaying the loss and accuracy after an epoch of neural-network training: at that moment the values would only consume resources, since their gradients do not matter for displaying results. A sketch combining both patterns follows below.
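Putting the two together, here is a hedged sketch of the typical pattern (all names here, model, criterion, train_step, evaluate, are hypothetical): no_grad for evaluation, detach() when logging the loss value.

import torch

# Hypothetical tiny setup, for illustration only
model = torch.nn.Linear(10, 2)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def train_step(x, target):
    optimizer.zero_grad()
    loss = criterion(model(x), target)
    loss.backward()
    optimizer.step()
    # detach() (or .item()) for logging: we only want the value, not the graph
    return loss.detach()

@torch.no_grad()              # evaluation: no graph is built at all
def evaluate(x, target):
    pred = model(x).argmax(dim=1)
    return (pred == target).float().mean()

x, target = torch.randn(8, 10), torch.randint(0, 2, (8,))
print(train_step(x, target), evaluate(x, target))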

Attributions

All content for this solution is sourced from the original question on Stackoverflow.

The content on this page is licensed under the Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.

Content Type        | Original Author | Original Content on Stackoverflow
Question            | user637140      | View Question on Stackoverflow
Solution 1 - Python | Anubhav Singh   | View Answer on Stackoverflow
Solution 2 - Python | prosti          | View Answer on Stackoverflow
Solution 3 - Python | SHAGUN SHARMA   | View Answer on Stackoverflow