PyTorch Notes

Posted by saltyfishyjk on 2023-08-04
Words 413 and Reading Time 1 Minutes
Viewed Times

PyTorch Notes

Part 0 简介

PyTorch 是目前最流行的深度学习框架之一。

参考资料:

Part 1

with torch.no_grad()

PyTorch 中的一个上下文管理器(context manager),提供一个临时的运行环境,在该环境中,所有 PyTorch 的操作不会被自动求导(AutoGrad),即,不会对计算图进行记录,从而节省内存并提高运行速度。

具体地,在深度学习中,训练阶段往往需要计算模型的梯度并使用梯度更新模型的参数,以进行反向传播和优化;在测试阶段,通常只需要进行前向传播以获取模型的输出,而不需要再计算梯度,因此,此时使用 with torch.no_grad() 可以节约资源,避免不必要的计算。

e.g.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss

for epoch in range(num_epochs):
for X, y in data_iter(batch_size, features, labels):
l = loss(net(X, w, b), y) # X和y的小批量损失
# 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,
# 并以此计算关于[w,b]的梯度
l.sum().backward()
sgd([w, b], lr, batch_size) # 使用参数的梯度更新参数
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

该代码片段来自d2l,其中,with torch.no_grad() 环境下进行 loss 的计算,用以直观表现每一个 epoch 后的优化效果(即应当观察到 loss 越来越小),此时并非要进行训练,因此只进行前向传播和计算(预测)即可。


This is copyright.