VB.net 2010 视频教程 VB.net 2010 视频教程 python基础视频教程
SQL Server 2008 视频教程 c#入门经典教程 Visual Basic从门到精通视频教程
当前位置:
首页 > temp > python入门教程 >
  • pytorch 中 torch.no_grad()、requires_grad、eval()

requires_grad

requires_grad=True 要求计算梯度;
requires_grad=False 不要求计算梯度;
在pytorch中,tensor有一个 requires_grad参数,如果设置为True,则反向传播时,该tensor就会自动求导。 tensor的requires_grad的属性默认为False,若一个节点(叶子变量:自己创建的tensor)requires_grad被设置为True,那么 所有依赖它的节点requires_grad都为True (即使其他相依赖的tensor的requires_grad = False)

x = torch.randn(10, 5, requires_grad = True)
y = torch.randn(10, 5, requires_grad = False)
z = torch.randn(10, 5, requires_grad = False)
w = x + y + z
w.requires_grad

输出:

True

volatile

volatile是Variable的另一个重要的标识,它能够将所有依赖它的节点全部设为volatile=True,优先级比requires_grad=True高。
而volatile=True的节点不会求导,即使requires_grad=True,也不会进行反向传播,对于不需要反向传播的情景(inference,测试阶段推断阶段),该参数可以实现一定速度的提升,并节省一半的显存,因为其不需要保存梯度。
但是, 注意 volatile已经取消了,使用with torch.no_grad()来替代 。

torch.no_grad()

是一个上下文管理器,被该语句 wrap 起来的部分将不会track 梯度。
with torch.no_grad()或者@torch.no_grad()中的数据不需要计算梯度,也不会进行反向传播。
(torch.no_grad()是新版本pytorch中volatile的替代)

x = torch.randn(2, 3, requires_grad = True)
y = torch.randn(2, 3, requires_grad = False)
z = torch.randn(2, 3, requires_grad = False)
m=x+y+z
with torch.no_grad():
    w = x + y + z
print(w)
print(m)
print(w.requires_grad)
print(w.grad_fn)
print(w.requires_grad)

输出:

tensor([[-2.7066, -0.7406,  0.5740],
        [-0.7071, -1.6057,  1.9732]])
tensor([[-2.7066, -0.7406,  0.5740],
        [-0.7071, -1.6057,  1.9732]], grad_fn=<AddBackward0>)
False
None
False

model.eval()与with torch.no_grad()

共同点:

在PyTorch中进行validation时,使用这两者均可切换到测试模式。

如用于通知dropout层和batchnorm层在train和val模式间切换。
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); batchnorm层会继续计算数据的mean和var等参数并更新。
在val模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。

不同点:

model.eval()会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反传。

with torch.zero_grad()则停止autograd模块的工作,也就是停止gradient计算,以起到加速和节省显存的作用,从而节省了GPU算力和显存,但是并不会影响dropout和batchnorm层的行为。

也就是说,如果不在意显存大小和计算时间的话,仅使用model.eval()已足够得到正确的validation的结果;而with torch.zero_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储gradient),从而可以更快计算,也可以跑更大的batch来测试。
 

出处:https://www.cnblogs.com/dychen/p/13921791.html


相关教程