PyTorch 自定義損失函數 (Custom Loss)
Sep 9, 2018 · 5 min read
一個自定義損失函數的類別 (class),是繼承自
nn.Module,進而使用 parent 類別的屬性與方法。
自定義損失函數的類別框架
如下,即是一個自定義損失函數的類別框架。
在__init__方法中,定義 child 類別的 hyper-parameters;而在forward方法中,定義損失函數的計算過程,因此,該 child 類別覆蓋了原先 parent 類別同名的屬性 (attribute) 和方法 (method),實現新的自定義損失函數。
import torch.nn as nnclass customLoss(nn.Module):
def __init__(self):
super(customLoss, self).__init__()
######################
### initialization ###
######################
def forward(self):
###########################
### define forward pass ###
###########################
passcriterion = customLoss()
以分類問題中常見的 CrossEntropyLoss 為例
以下程式碼,以自定義損失函數來實現與官方相同的CrossEntropyLoss。( PyTorch version: 0.4.1)
import torch
import torch.nn as nnclass customLoss(nn.Module):
def __init__(self, weight):
super(customLoss, self).__init__()
self.softmax = nn.Softmax(dim=1)
self.weight = weight
def forward(self, outputs, targets):
# transform targets to one-hot vector
targets_onehot = torch.zeros_like(outputs)
targets_onehot.zero_()
targets_onehot.scatter_(1, targets.unsqueeze(-1), 1)
# nn.CrossEntropyLoss
# combines nn.LogSoftmax() and nn.NLLLoss()
outputs = self.softmax(outputs)
self.weight = self.weight.expand_as(outputs)
loss = -targets_onehot.float() * torch.log(outputs)
return torch.mean(self.weight * loss)# define CrossEntropyLoss with weights
weight = torch.Tensor([1, 5, 10])# define inputs, official and custom loss
outputs = torch.Tensor([[0.9, 0.5, 0.05], [0.01, 0.2, 0.7]])
targets = torch.Tensor([0, 1]).long()criterion = nn.CrossEntropyLoss(weight=weight)
custom_criterion = customLoss(weight=weight)# run metrics
loss = criterion(outputs, targets)
custom_loss = custom_criterion(outputs, targets)print ('official loss: ', loss.item())
print ('custom loss: ', custom_loss.item())
顯示輸出如下,
official loss: 1.1616348028182983
custom loss: 1.161634922027588此外,官方提供許多張量操作的方法[1],在定義損失函數的forward方法時,可以盡量採用,增加執行效率。
nn.Module可以看成是 parameters 的容器,官方稱為Base class for all neural network modules. ,同學們可以透過繼承此類別,自定義所需的模型與損失函數,而過程中僅需定義forward方法,反向傳播的定義會在宣告張量為requires_grad=True時被自動產生,實現 PyTorch torch.autograd的特性。
感謝您的閱讀,如果文章有益請在底下長按拍手
有任何問題歡迎在底下留言或是來信交流wanju.ts@gmail.com
