Preface

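Training a network from scratch consumes a large amount of compute, so the backbone is usually a network with strong feature-extraction ability, such as ResNet-50 or HRNet-W32. If you modify the model structure yourself, however, you can no longer load the complete pretrained weights directly.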
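To see why, here is a minimal sketch. It assumes a hypothetical two-layer `TinyNet` in place of ResNet, and a randomly initialised 1000-class state dict in place of a real checkpoint. Loading that state dict into a 5-class version fails on the classifier's shape. Passing `strict=False` does not help either, because it only tolerates missing or unexpected keys, not size mismatches.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a backbone plus a final fc classifier (not ResNet)."""

    def __init__(self, num_classes: int = 1000):
        super().__init__()
        self.features = nn.Linear(16, 8)     # plays the role of the backbone
        self.fc = nn.Linear(8, num_classes)  # classifier head, named like ResNet's fc

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(torch.relu(self.features(x)))


# Stand-in for a pretrained 1000-class checkpoint (random weights, not real ones).
pretrained_state = TinyNet(num_classes=1000).state_dict()
modified = TinyNet(num_classes=5)  # structure changed: fc now has 5 outputs

errors = {}
for strict in (True, False):
    try:
        modified.load_state_dict(pretrained_state, strict=strict)
    except RuntimeError as err:
        # fc.weight is [1000, 8] in the checkpoint but [5, 8] in the model
        errors[strict] = "size mismatch" in str(err)

print(errors)  # {True: True, False: True}
```

Both calls raise, which is why the two methods below either load the weights before changing the structure or remove the mismatched entries first.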

Methods

Method 1

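Load the pretrained weights into the unmodified network first. Then replace the last fully connected layer (fc) with a new one whose output size is 5 (the default is 1000).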

import os
import torch
import torch.nn as nn
from model import resnet34

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)

net = resnet34()
net.load_state_dict(torch.load(model_weight_path, map_location=device))
# change fc layer structure
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, 5)
# move the whole model, including the newly created fc layer, to the target device
net.to(device)
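The same steps can be checked on a toy model. This sketch uses the hypothetical `TinyNet` (not the `model.resnet34` used above) and a random state dict standing in for `resnet34-pre.pth`. The checkpoint loads strictly because the structure still matches, and only the replacement fc starts from random initialisation.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet34: a 'backbone' plus an fc head."""

    def __init__(self, num_classes: int = 1000):
        super().__init__()
        self.features = nn.Linear(16, 8)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(torch.relu(self.features(x)))


torch.manual_seed(0)
pretrained_state = TinyNet().state_dict()  # stands in for resnet34-pre.pth

net = TinyNet()                             # same structure as the checkpoint
net.load_state_dict(pretrained_state)       # strict load succeeds
net.fc = nn.Linear(net.fc.in_features, 5)   # replace the head after loading

out = net(torch.randn(2, 16))
print(out.shape)  # torch.Size([2, 5])
# The backbone still holds the pretrained values.
print(torch.equal(net.features.weight, pretrained_state["features.weight"]))  # True
```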

Method 2

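Build the network with 5 output classes, delete the weights of the last fully connected layer from the pretrained weights, and load only the remaining layers.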

import os
import torch
import torch.nn as nn
from model import resnet34

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(
    model_weight_path)

net = resnet34(num_classes=5).to(device)
pre_weights = torch.load(model_weight_path, map_location=device)

# find the keys that contain "fc" and delete them from the pretrained weights
del_key = []
for key, _ in pre_weights.items():
    if "fc" in key:
        del_key.append(key)

for key in del_key:
    del pre_weights[key]

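# load_state_dict returns the missing keys and the unexpected keys
# strict=False: the keys in the weights need not exactly match the model's keys
# (the shapes of keys present in both must still agree)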
missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
print("[missing_keys]:", *missing_keys, sep="\n")
print("[unexpected_keys]:", *unexpected_keys, sep="\n")

Terminal output:

[missing_keys]:
fc.weight
fc.bias
[unexpected_keys]:
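Matching the substring "fc" works for resnet34, whose classifier is the only layer with "fc" in its key names, but it depends on how the layers are named. A more general variant keeps only the checkpoint entries whose key exists in the new model and whose tensor shape matches. This is a swapped-in technique, not the referenced repository's code. The sketch again uses the hypothetical `TinyNet` and a random state dict instead of real ResNet weights. The helper `filter_matching` is likewise hypothetical.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet34: a 'backbone' plus an fc head."""

    def __init__(self, num_classes: int = 1000):
        super().__init__()
        self.features = nn.Linear(16, 8)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(torch.relu(self.features(x)))


def filter_matching(pre_weights: dict, model: nn.Module) -> dict:
    """Hypothetical helper: keep entries whose key exists in `model` with the same shape."""
    model_state = model.state_dict()
    return {
        k: v
        for k, v in pre_weights.items()
        if k in model_state and v.shape == model_state[k].shape
    }


pre_weights = TinyNet(num_classes=1000).state_dict()  # stand-in checkpoint
net = TinyNet(num_classes=5)

filtered = filter_matching(pre_weights, net)  # drops fc.weight and fc.bias
missing_keys, unexpected_keys = net.load_state_dict(filtered, strict=False)
print("[missing_keys]:", *missing_keys, sep="\n")
print("[unexpected_keys]:", *unexpected_keys, sep="\n")
```

The printed lists match the output shown above: only `fc.weight` and `fc.bias` are missing, and nothing is unexpected.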



References

Code adapted from: https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/blob/master/pytorch_classification/Test5_resnet/load_weights.py
