1. 仅保存权重信息

# 模型路径
path = "state_dict_model.pt"

# 保存
torch.save(model.state_dict(), path)

# 加载
model = Network()
# 将训练好的权重加载到模型中
model.load_state_dict(torch.load(path))

2. 保存全部信息

# 对整个模型进保存和加载
path = "entire_model.pt"

# 保存模型
torch.save(model, path)

# 加载模型
model = torch.load(path)

3. 保存checkpoint

# 保存checkpoint
path = 'model.pt'
torch.save(
    {
        'epoch':epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict':optimizer.state_dict(),
        'loss': loss_fn
    },path
)

# 加载
model = Network(input_num)
optimizer = optimizer = torch.optim.SGD(model.parameters(), lr=lr)

checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]

4. 其他测试

当我们打印模型时:

net = MyModel(3)

print(net.state_dict().items()) # 输出模型每一层的权重
print(net.state_dict())   # 输出模型每一层的权重