PyTorch 学习笔记 (9): 模型管理
2025-12-20·10 min read
#PyTorch#Deep Learning#Model Management
模型管理是深度学习项目中的必备技能!
主要场景:
- 训练中断后恢复(checkpoint)
- 模型部署和推理
- 迁移学习(使用预训练模型)
- 模型版本管理
模型保存的三种方式
方式1:只保存模型参数(推荐)
python
import torch
import torch.nn as nn
import torch.optim as optim
# 定义模型
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 10)
self.relu = nn.ReLU()
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
model = SimpleNet()
# 保存模型参数
torch.save(model.state_dict(), 'model_weights.pth')
优点:文件小,灵活性高,最常用
方式2:保存整个模型
python
torch.save(model, 'model_whole.pth')
注意:可能有兼容性问题,不推荐用于生产环境
方式3:保存 checkpoint(训练状态)
python
optimizer = optim.Adam(model.parameters(), lr=0.001)
epoch = 50
loss = 0.123
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}
torch.save(checkpoint, 'checkpoint.pth')
用途:保存完整训练状态,用于中断后恢复训练
加载模型
加载模型参数
python
# 必须先定义模型结构
loaded_model = SimpleNet()
# 加载参数
state_dict = torch.load('model_weights.pth')
loaded_model.load_state_dict(state_dict)
loaded_model.eval() # 设置为评估模式
从 checkpoint 加载
python
new_model = SimpleNet()
new_optimizer = optim.Adam(new_model.parameters(), lr=0.001)
# 加载 checkpoint
checkpoint = torch.load('checkpoint.pth')
new_model.load_state_dict(checkpoint['model_state_dict'])
new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
模型推理
推理时的重要注意事项
- 设置为评估模式:
model.eval()- 禁用 Dropout 和 BatchNorm 的训练行为 - 使用 torch.no_grad():不计算梯度,节省内存
- 处理输入数据:转换为 Tensor,添加 batch 维度
- 处理输出:移回 CPU,转换为 numpy
标准推理函数
python
def predict(model, input_data, device='cpu'):
"""标准的推理函数"""
model.eval() # 设置为评估模式
# 处理输入
if isinstance(input_data, np.ndarray):
input_data = torch.tensor(input_data, dtype=torch.float32)
# 添加 batch 维度(如果需要)
if input_data.dim() == 1:
input_data = input_data.unsqueeze(0)
# 移动到设备
input_data = input_data.to(device)
model = model.to(device)
# 推理
with torch.no_grad():
output = model(input_data)
# 移回 CPU 并转换为 numpy
output = output.cpu().numpy()
return output
批量推理
python
def batch_predict(model, inputs, batch_size=32, device='cpu'):
"""批量推理"""
model.eval()
predictions = []
with torch.no_grad():
for i in range(0, len(inputs), batch_size):
batch = inputs[i:i+batch_size].to(device)
output = model(batch)
predictions.append(output.cpu())
return torch.cat(predictions, dim=0)
完整的训练循环(带 checkpoint)
python
import os
from datetime import datetime
def train_with_checkpointing(model, train_loader, val_loader, epochs, save_dir='checkpoints'):
"""带 checkpoint 保存的训练循环"""
os.makedirs(save_dir, exist_ok=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
best_val_loss = float('inf')
for epoch in range(epochs):
# 训练
model.train()
for batch_x, batch_y in train_loader:
optimizer.zero_grad()
output = model(batch_x)
loss = criterion(output, batch_y)
loss.backward()
optimizer.step()
# 验证
model.eval()
val_loss = 0.0
with torch.no_grad():
for batch_x, batch_y in val_loader:
output = model(batch_x)
val_loss += criterion(output, batch_y).item()
val_loss /= len(val_loader)
# 定期保存 checkpoint
if (epoch + 1) % 10 == 0:
checkpoint = {
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'val_loss': val_loss,
'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
torch.save(checkpoint, f'{save_dir}/checkpoint_epoch{epoch+1}.pth')
# 保存最佳模型
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save({
'model_state_dict': model.state_dict(),
'val_loss': val_loss,
'epoch': epoch + 1
}, f'{save_dir}/best_model.pth')
最佳实践
| 场景 | 推荐方法 |
|---|---|
| 保存最终模型 | state_dict() + model_config.json |
| 训练中断恢复 | 完整 checkpoint(包含优化器状态) |
| 部署到生产 | ONNX + TorchScript |
| 模型版本管理 | 文件名包含时间戳或版本号 |
文件命名规范
text
model_weights.pth # 模型参数
model_config.json # 模型配置
checkpoint_epoch{N}.pth # 定期保存
best_model.pth # 最佳模型
model_YYYYMMDD_HHMMSS.pth # 时间戳
重要提醒
- 始终保存模型配置(超参数)
- 定期保存 checkpoint(防训练中断)
- 推理前设置
model.eval() - 推理时使用
torch.no_grad() - 测试加载函数(确保能正确恢复)
总结
模型管理是深度学习工程化的关键:
- 保存:使用
state_dict()保存参数 - 加载:先定义模型结构,再加载参数
- checkpoint:保存完整训练状态
- 推理:
eval()+no_grad()