PyTorch: Modifying, Adding, and Deleting Pre-trained Model Layers

李謦伊
謦伊的閱讀筆記
17 min read · Jun 26, 2022

Pre-trained model weights are often used during training, but a pre-trained model doesn't always fit our needs exactly, so we usually make some changes to it. This article introduces how to modify, add, and delete pre-trained model layers with PyTorch; all of the code can be found at the bottom of the article.

First, import the required libraries.

import torch
import torch.nn as nn
import torchvision.models as models
from collections import OrderedDict

Loading and inspecting pre-trained model weights

You can use the model weights provided by PyTorch's torchvision.models, or a weight file that you trained yourself or downloaded.

  • Using the pre-trained model weights provided by PyTorch

When loading a model, setting pretrained=True loads that model's weights. You can then use PyTorch's state_dict to inspect the layer names and parameters.

state_dict stores everything as a dictionary. As the output below shows, its type is collections.OrderedDict, where the keys are the layer names and the values are the parameter values.

model = models.resnet18(pretrained=True)
model_state = model.state_dict()
print("model_state type:", type(model_state))
for param_tensor in model_state:
    print("name:", param_tensor)
    print("value:", model_state[param_tensor])
# === output ===
model_state type: <class 'collections.OrderedDict'>
name: conv1.weight
value: tensor([[[[-1.0419e-02, -6.1356e-03, -1.8098e-03, ..., 5.6615e-02,
1.7083e-02, -1.2694e-02],
[ 1.1083e-02, 9.5276e-03, -1.0993e-01, ..., -2.7124e-01,
-1.2907e-01, 3.7424e-03],
...,
name: bn1.weight
value: tensor([ 2.3487e-01, 2.6626e-01, -5.1096e-08, 5.1870e-01, 3.4404e-09,
2.2239e-01, 4.2289e-01, 1.3153e-07, 2.5093e-01, 1.5152e-06,

...
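Note that in newer torchvision releases (0.13 and later), pretrained=True is deprecated in favor of the weights argument. A minimal equivalent sketch, depending on your torchvision version:

from torchvision.models import ResNet18_Weights

# Loads the same ImageNet weights that pretrained=True used to load
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)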
  • Using weights you trained yourself or downloaded

Taking resnet18 as an example, first download the resnet18 weight file.

import gdown

resnet_model = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
gdown.download(resnet_model, "resnet-5c106cde.pth")

Load the weight file and check its type. As shown below, it is also a collections.OrderedDict, with layer names as keys and parameter values as values.

checkpoint = torch.load('resnet-5c106cde.pth')
print("checkpoint type:", type(checkpoint))

for k, v in checkpoint.items():
    print("name:", k)
    print("value:", v)
# === output ===
checkpoint type: <class 'collections.OrderedDict'>
name: conv1.weight
value: Parameter containing:
tensor([[[[-1.0419e-02, -6.1356e-03, -1.8098e-03, ..., 5.6615e-02,
1.7083e-02, -1.2694e-02],
[ 1.1083e-02, 9.5276e-03, -1.0993e-01, ..., -2.7124e-01,
-1.2907e-01, 3.7424e-03],
...,
,requires_grad=True)
name: bn1.running_mean
value: tensor([ 2.7681e-03, -2.5769e-02, 2.1254e-07, -8.4605e-02, 2.1121e-08,
4.9691e-04, -2.2408e-02, -1.1582e-07, -4.8239e-03, 2.7507e-07,
...

Modifying layers

This part is split into modifying a layer's parameters and modifying its name.

🔖 Parameters

  • Modifying the output size of the last layer

When finetuning a pre-trained model on your own dataset, you need to change the model's last-layer output size to the number of classes in your dataset.

Taking resnet18 as an example, the last linear layer has an output size of 1000.

model = models.resnet18(pretrained=True)
model.fc
# === output ===
Linear(in_features=512, out_features=1000, bias=True)

Suppose your dataset has 10 classes; change the linear layer's output to that number of classes. As shown below, the output size becomes 10.

in_features = model.fc.in_features
num_class = 10
model.fc = nn.Linear(in_features, num_class)
print(model.fc)
# === output ===
Linear(in_features=512, out_features=10, bias=True)

Inspecting the linear layer with state_dict() also shows that the sizes have become [10, 512] and [10].

model_state = model.state_dict()
print("weight: ", model_state['fc.weight'].shape)
print("bias: ", model_state['fc.bias'].shape)
# === output ===
weight: torch.Size([10, 512])
bias: torch.Size([10])
  • Modifying the parameters of a specific layer

Since the loaded weight file is an ordered dictionary, parameter values can be changed through their keys.

Here we demonstrate modifying a loaded weight file; for the approach using PyTorch's own pre-trained weights, see the code at the bottom of the article. Suppose we want to modify the fully connected layer's weight and bias; first, let's look at their current values.

checkpoint = torch.load('resnet-5c106cde.pth')

for k, v in checkpoint.items():
    if k in ['fc.weight', 'fc.bias']:
        print("name:", k)
        print("value:", v)
# === output ===
name: fc.weight
value: Parameter containing:
tensor([[-0.0185, -0.0705, -0.0518, ..., -0.0390, 0.1735, -0.0410],
[-0.0818, -0.0944, 0.0174, ..., 0.2028, -0.0248, 0.0372],
...,
,requires_grad=True)
name: fc.bias
value: Parameter containing:
tensor([-2.6341e-03, 3.0005e-03, 6.5581e-04, -2.6909e-02, 6.3637e-03,
1.3260e-02, -1.1178e-02, 2.0639e-02, -3.6373e-03, -1.2325e-02,
...,
,requires_grad=True)

Modify the fully connected layer's weight. As shown below, the weight's size becomes [10, 512].

print("org: ", checkpoint['fc.weight'].shape)
checkpoint['fc.weight'] = torch.rand((10, 512))
print("now: ", checkpoint['fc.weight'].shape)
# === output ===
org: torch.Size([1000, 512])
now: torch.Size([10, 512])

Modify the fully connected layer's bias. As shown below, the bias's size becomes [10].

print("org: ", checkpoint['fc.bias'].shape)
checkpoint['fc.bias'] = torch.ones(10)
print("now: ", checkpoint['fc.bias'].shape)
# === output ===
org: torch.Size([1000])
now: torch.Size([10])

Check that the parameter values of the specified layers have been changed to the values we just set.

for k, v in checkpoint.items():
    if k in ['fc.weight', 'fc.bias']:
        print("name:", k)
        print("value:", v)
# === output ===
name: fc.weight
value: tensor([[0.8379, 0.6310, 0.5764, ..., 0.0857, 0.5956, 0.9491],
[0.2773, 0.6764, 0.9006, ..., 0.2598, 0.9735, 0.7401],
[0.2910, 0.0501, 0.1576, ..., 0.9090, 0.8763, 0.4543],
...,
[0.0213, 0.0820, 0.9081, ..., 0.3636, 0.6353, 0.1084],
[0.5539, 0.2268, 0.7676, ..., 0.8510, 0.4865, 0.2786],
[0.6941, 0.2103, 0.9243, ..., 0.4558, 0.7547, 0.6159]])
name: fc.bias
value: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
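
After editing the checkpoint like this, it can be loaded into a model whose last layer has a matching shape. A minimal sketch, assuming the 10-class setup above:

# Resize the model's fc layer to 10 classes so it matches the edited checkpoint,
# then load the modified weights
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(checkpoint)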

🔖 Names

You will often use weights from someone else's model, but the layer names (keys) defined in those weights may not match your own model. When the model's keys and the loaded weights' keys don't match, you will see the following error (using resnet18 as an example).

RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias",...
Unexpected key(s) in state_dict: "resnet.conv1.weight", "resnet.bn1.weight", "resnet.bn1.bias", ...

Check the model's keys:

model = models.resnet18()

for param_tensor in model.state_dict():
    print("name:", param_tensor)
    print("value:", model.state_dict()[param_tensor].size())
# === output ===
name: conv1.weight
value: torch.Size([64, 3, 7, 7])
name: bn1.weight
value: torch.Size([64])
name: bn1.bias
value: torch.Size([64])
name: bn1.running_mean
value: torch.Size([64])
...

Next, check the keys in the weight file:

checkpoint = torch.load('resnet_weights.pth')

for k, v in checkpoint.items():
    print("name:", k)
    print("value:", v.size())
# === output ===
name: resnet.conv1.weight
value: torch.Size([64, 3, 7, 7])
name: resnet.bn1.weight
value: torch.Size([64])
name: resnet.bn1.bias
value: torch.Size([64])
name: resnet.bn1.running_mean
value: torch.Size([64])
...

Comparing the two, the only difference is the resnet. prefix. Next, build a new OrderedDict whose keys match the model's keys and whose values are the original weight values. With that, the weight file loads successfully~

state_dict = OrderedDict()
for k, v in checkpoint.items():
    state_dict[k[len('resnet.'):]] = v
model.load_state_dict(state_dict)
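
For the common case of stripping a fixed prefix (for example the module. prefix added by DataParallel/DistributedDataParallel), recent PyTorch versions also ship a helper that does the same thing in place. A minimal sketch, assuming every key really starts with resnet.:

from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Removes the 'resnet.' prefix from every key in the checkpoint, in place
consume_prefix_in_state_dict_if_present(checkpoint, 'resnet.')
model.load_state_dict(checkpoint)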

Adding layers

If you modify or add to the network architecture, you only need to extract, from the pre-trained model, the weights of the layers it shares with your custom model.

Let's try it out~ First, add some layers on top of resnet18.

class MyResNet18(nn.Module):
    def __init__(self, net_block, layers, num_classes=10):
        super(MyResNet18, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self.net_block_layer(net_block, 64, layers[0])
        self.layer2 = self.net_block_layer(net_block, 128, layers[1], stride=2)
        self.layer3 = self.net_block_layer(net_block, 256, layers[2], stride=2)
        self.layer4 = self.net_block_layer(net_block, 512, layers[3], stride=2)
        ## ============== added layers ============ ##
        # layer4 outputs 512 * expansion channels, so that is the new conv's input size
        self.layer5 = nn.Sequential(nn.Conv2d(512 * net_block.expansion, 128, kernel_size=3, stride=2, padding=1),
                                    nn.BatchNorm2d(128),
                                    nn.ReLU(inplace=True))
        ## ======================================= ##
        self.avgpooling = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * net_block.expansion, num_classes)
    ...

num_classes = 10
model = MyResNet18(basic_block, [2, 2, 2, 2], num_classes)
model_state = model.state_dict()

Load the pre-trained model. Note that the network's last-layer output size needs to be changed to the number of classes in your own dataset.

checkpoint = torch.load('resnet-5c106cde.pth')
checkpoint['fc.weight'] = torch.zeros((num_classes, 512))
checkpoint['fc.bias'] = torch.zeros(num_classes)

Then extract the weights of the layers that the pre-trained model shares with the custom model.

pretrained_dict = {k: v for k, v in checkpoint.items() if k in model_state}
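
Matching by key alone assumes the shared layers also have the same shapes (which holds here because fc.weight and fc.bias in the checkpoint were resized above). A slightly more defensive sketch also filters on shape:

# Keep only weights whose key exists in the model and whose shape matches
pretrained_dict = {k: v for k, v in checkpoint.items()
                   if k in model_state and v.shape == model_state[k].shape}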

Update the state dict with these weights and load it into the model. If the output shows <All keys matched successfully>, loading succeeded~

model_state.update(pretrained_dict)
model.load_state_dict(model_state)
# === output ===
<All keys matched successfully>
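
As a shortcut, load_state_dict also accepts strict=False, which simply ignores keys that are missing on either side; the newly added layer5 then keeps its random initialization. A minimal sketch:

# layer5.* shows up in the returned missing_keys and stays randomly initialized
model.load_state_dict(checkpoint, strict=False)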

Deleting layers

To delete specific layers from a pre-trained model, you can delete them by layer name.

Taking resnet18 as an example, suppose the layer to delete is layer4.1. First load the pre-trained weights and look at the names of the layer4.1 entries.

checkpoint = torch.load('resnet-5c106cde.pth')

for k in list(checkpoint.keys()):
    if k.startswith('layer4.1'):
        print(k)
# === output ===
layer4.1.conv1.weight
layer4.1.bn1.running_mean
layer4.1.bn1.running_var
layer4.1.bn1.weight
layer4.1.bn1.bias
layer4.1.conv2.weight
layer4.1.bn2.running_mean
layer4.1.bn2.running_var
layer4.1.bn2.weight
layer4.1.bn2.bias

Delete layer4.1.

for k in list(checkpoint.keys()):
    if k.startswith('layer4.1'):
        del checkpoint[k]

Then verify that the deletion succeeded; an output containing only 'None' means the specified layers have been deleted~

import numpy as np

a = ["None" if not k.startswith('layer4.1') else "Exists" for k in checkpoint.keys()]
print(np.unique(a))
# === output ===
['None']
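
Keep in mind that deleting keys from the checkpoint only removes the weights; the layer itself is still part of the model architecture. A minimal sketch of one way to also drop the block from a standard resnet18 (nn.Identity is shape-compatible here because a BasicBlock's input and output both have 512 channels):

model = models.resnet18()
model.layer4[1] = nn.Identity()    # remove the block from the architecture as well
model.load_state_dict(checkpoint)  # the trimmed checkpoint now matches the modified model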
