PyTorch Custom Loss Functions (Custom Loss)

rowan.ts
Sep 9, 2018 · 5 min read

A custom loss function is a class that inherits from nn.Module, gaining access to the parent class's attributes and methods.

The class skeleton of a custom loss function

Below is the class skeleton for a custom loss function.

In the __init__ method, define the child class's hyper-parameters; in the forward method, define how the loss is computed. The child class thereby overrides the parent class's attributes and methods of the same name, implementing the new custom loss function.

import torch.nn as nn

class customLoss(nn.Module):
    def __init__(self):
        super(customLoss, self).__init__()
        ######################
        ### initialization ###
        ######################

    def forward(self):
        ###########################
        ### define forward pass ###
        ###########################
        pass

criterion = customLoss()
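As a concrete illustration of this skeleton, here is a minimal sketch of a hand-rolled mean-squared-error loss (the class name customMSELoss is my own illustrative choice, not from the official API):

import torch
import torch.nn as nn

# hypothetical example: an MSE loss built on the skeleton above
class customMSELoss(nn.Module):
    def __init__(self):
        super(customMSELoss, self).__init__()

    def forward(self, outputs, targets):
        # average of the element-wise squared errors
        return torch.mean((outputs - targets) ** 2)

criterion = customMSELoss()
loss = criterion(torch.randn(4, 3), torch.randn(4, 3))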

An example: CrossEntropyLoss, a common loss for classification

The following code uses a custom loss function to reproduce the official CrossEntropyLoss. (PyTorch version: 0.4.1)

import torch
import torch.nn as nn

class customLoss(nn.Module):
    def __init__(self, weight):
        super(customLoss, self).__init__()
        self.softmax = nn.Softmax(dim=1)
        self.weight = weight

    def forward(self, outputs, targets):
        # transform targets into a one-hot matrix of the same shape as outputs
        targets_onehot = torch.zeros_like(outputs)
        targets_onehot.scatter_(1, targets.unsqueeze(-1), 1)

        # nn.CrossEntropyLoss combines nn.LogSoftmax() and nn.NLLLoss()
        outputs = self.softmax(outputs)
        # per-element negative log-likelihood; nonzero only at the target class
        loss = -targets_onehot.float() * torch.log(outputs)
        # broadcast the class weights across the batch (keep self.weight
        # untouched so later calls with a different batch size still work)
        weight = self.weight.expand_as(outputs)
        # the official CrossEntropyLoss normalizes by the summed weights
        # of the target classes, not by the element count
        return torch.sum(weight * loss) / self.weight[targets].sum()

# define CrossEntropyLoss with class weights
weight = torch.Tensor([1, 5, 10])

# define inputs, official and custom loss
outputs = torch.Tensor([[0.9, 0.5, 0.05], [0.01, 0.2, 0.7]])
targets = torch.Tensor([0, 1]).long()
criterion = nn.CrossEntropyLoss(weight=weight)
custom_criterion = customLoss(weight=weight)

# run both criteria
loss = criterion(outputs, targets)
custom_loss = custom_criterion(outputs, targets)
print('official loss: ', loss.item())
print('custom loss: ', custom_loss.item())

The output is as follows:

official loss:  1.1616348028182983
custom loss:  1.161634922027588
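As a side note, taking softmax first and then log can underflow for very confident predictions. Here is a sketch of a numerically safer variant of the same computation using F.log_softmax and torch.gather; this formulation is my own, not part of the original code, and it reuses outputs, targets, and weight from above:

import torch
import torch.nn.functional as F

def weighted_ce(outputs, targets, weight):
    # log_softmax fuses log and softmax, avoiding log(0) underflow
    log_probs = F.log_softmax(outputs, dim=1)
    # pick out each sample's log-probability at its target class
    picked = log_probs.gather(1, targets.unsqueeze(-1)).squeeze(-1)
    w = weight[targets]
    # normalize by the summed target-class weights, as nn.CrossEntropyLoss does
    return -(w * picked).sum() / w.sum()

print(weighted_ce(outputs, targets, weight).item())  # ~1.1616, matching above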

In addition, PyTorch provides many built-in tensor operations [1]; prefer them when defining the forward method of a loss function, as they improve execution efficiency. The sketch below compares the two approaches.
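For example, the one-hot encoding above could be written with a Python loop, but the built-in scatter_ does the same work in a single vectorized call. A small sketch for comparison:

import torch

targets = torch.tensor([0, 1]).long()
num_classes = 3

# Python loop: easy to read, but slow for large batches
onehot_loop = torch.zeros(len(targets), num_classes)
for i, t in enumerate(targets):
    onehot_loop[i, t] = 1

# built-in tensor operation: one call, executed in optimized native code
onehot_vec = torch.zeros(len(targets), num_classes).scatter_(
    1, targets.unsqueeze(-1), 1)

assert torch.equal(onehot_loop, onehot_vec)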

nn.Module can be thought of as a container for parameters; the official documentation describes it as the "Base class for all neural network modules." By subclassing it you can define your own models and loss functions, and you only need to implement the forward method: the backward pass is generated automatically for tensors declared with requires_grad=True, which is the torch.autograd feature of PyTorch.
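A minimal illustration of this autograd behaviour:

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (x ** 2).sum()  # only the forward computation is written by hand
loss.backward()        # autograd derives and runs the backward pass
print(x.grad)          # tensor([2., 4.]), i.e. d(sum(x^2))/dx = 2x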


Thank you for reading. If you found this article helpful, please press and hold the clap button below.
If you have any questions, feel free to leave a comment below or email me at wanju.ts@gmail.com.

