機器學習分類器在沈浸式環境之應用 — — 2. Pytorch Custom Dataset

ExcitedMail
SWF Lab
Published in
10 min readApr 10, 2023

之前寫了一篇自己訓練 CNN 來進行 Image Classification (Pytorch),不過不是每次都能有別人準備好的 Dataset 可以用,這次就來介紹從網路下載資料下來後要如何寫 Custom Dataset。

Dataset 負責的事情是把資料的資訊抓出來,在這之前我們需要給定資料的位置、資料的轉換方式以及其他額外需要的資訊,以下會用簡單的貓狗圖片的 Dataset 做示範。

開始之前先謝謝 Corn 以及 QQAI 幫我 Review 這篇文章 (痛哭流涕

Table of Contents

  • Intro.
  • Custom Dataset
  • Check Result
  • Conclusion
  • Reference

Intro.

這次的資料是從 Kaggle 的 competition 來的,請先到這邊下載好檔案,Kaggle 可能會要求先加入此次 competition 才能下載,另外記得建立新的資料夾做這次的內容。

下載好解壓縮後可以看到裡面有 sampleSubmission.csv 還有 test1.zip 以及 train.zip,把兩個都解壓縮得到 test1 及 train 資料夾,創一個 data 資料夾存這兩個圖片資料夾,以及一個 src 資料夾存 code ,像下面這個格式:

在開始寫程式之前,我們需要先知道這次資料的格式長怎麼樣。打開 train 資料夾,可以看到他是用 <cat or dog>.<number>.jpg 來命名的,如下圖:

而 test1 的資料夾內則沒有貓或狗的答案,畢竟 test 就是題目,不該一開始就知道答案。

觀察完這些之後就可以開始正式寫 Dataset 了!

Custom Dataset

在 src 中建立一個 main.py ,首先 import 需要的 library:

import os
from torch.utils.data import Dataset
from PIL import Image

接著需要繼承 Dataset 這個類別,以及完成底下三個 function ,除了 __init__ 可以自己定義輸入的值以外,其他兩個基本上輸入不能修改。

__init__ 是在最一開始建立 Dataset 被呼叫的 function,我們需要給它資料夾的位置、圖片轉換的形式(需要轉換成 torch.Tensor 才能餵給 model)以及告訴它這是否是 training 的資料(train 及 test 只有前者有答案)。
我們先把資料都存在 self 中,離開 __init__ 後 Dataset 才會繼續記得這些值,另外要把所有圖片名稱存在 self.files 中,讓之後的 __getitem__ 知道自己現在需要呼叫哪張圖片。

    def __init__(self, path, tfm, train=False):
# 資料的位置
self.path = path
# 資料的 list
self.files = sorted([os.path.join(path,x) for x in os.listdir(path) if x.endswith(".jpg")])
# 資料轉換的形式
self.transform = tfm
# 是否為 training
self.train = train

__len__ 的作用是回傳 Dataset 中資料數量,這邊直接呼叫 self.files 的長度即可。

    def __len__(self):
# 回傳資料 list 的長度
return len(self.files)

__getitem__ 會是這三個之中最常被呼叫的 function ,在建立完 dataset 之後,會交給 DataLoader 把資料一批一批抓出來,輸入會有一個 index 值,這邊姑且命名為 idx ,我們需要先把檔案名稱從 self.files 抓出來,把圖片讀取出來後交給 transform 轉換成 torch.Tensor 的型態。

    def __getitem__(self, idx):
# 抓出檔案名稱
fname = self.files[idx]
# 讀取並做轉換
im = Image.open(fname)
im = self.transform(im)

# 貓設定為 0, 狗設定為 1
classList = ['cat', 'dog']
if(self.train == True):
label = fname.split("/")[-1].split(".")[0]
label = classList.index(label)
else:
label = -1

return im, label

到這邊就算是完成 Dataset 的部份了!接著我們把 DataLoader 吐出來的圖片印出來看看是否正確。

Check Result

加上前面的 Dataset ,我們可以執行以下程式:

import os
from torch.utils.data import Dataset
from PIL import Image

class CatDogDataset(Dataset):

def __init__(self, path, tfm, train=False):
self.path = path
self.files = sorted([os.path.join(path,x) for x in os.listdir(path) if x.endswith(".jpg")])
self.transform = tfm
self.train = train

def __len__(self):
return len(self.files)

def __getitem__(self, idx):
fname = self.files[idx]
im = Image.open(fname)
im = self.transform(im)

classList = ['cat', 'dog']
if(self.train == True):
label = fname.split("/")[-1].split(".")[0]
label = classList.index(label)
else:
label = -1

return im, label


from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.utils import make_grid
# 把圖片排成方格狀
def image_grid(raw_imgs, rows, cols):
assert len(raw_imgs) == rows*cols

imgs = []
transform = transforms.ToPILImage()
for i in range(len(raw_imgs)):
imgs.append(transform(raw_imgs[i]))

w, h = imgs[0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
grid_w, grid_h = grid.size

for i, img in enumerate(imgs):
grid.paste(img, box=(i%cols*w, i//cols*h))
return grid
# 簡單的轉換
tfm = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.ToTensor(),
])
# 建立 Dataset
train_set = CatDogDataset("./data/train", tfm=tfm, train=True)
# 交給 DataLoader
train_loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=0, pin_memory=True)
# 取得一個 batch
batch = next(iter(train_loader))
imgs, labels = batch
grid = image_grid(imgs, rows=4, cols=4)
grid.show()
# 把 label 印出來
for idx, label in enumerate(labels):
print(idx, label.item())

參考網路上的這篇文章,把圖片排列成方便檢視的 4x4 方格狀。雖然到時候丟給 model 時需要轉換成 torch.Tensor 的型態,但為了印出來確認輸出是否有錯,我們還是要把它傳換回 PIL 的型態印出來看。

因為在 Dataset 中把貓放在 list 的第一位,所以貓的 label 為 0 且狗的 label 為 1,可以看到這邊有吐出正確的圖片跟 label 了。

另外也可以把 test 的部份印出來看看,只要把 train_set 以下的部份換成下面這段即可:

test_set = CatDogDataset("./data/test1", tfm=tfm, train=False)
test_loader = DataLoader(test_set, batch_size=16, shuffle=True, num_workers=0, pin_memory=True)

batch = next(iter(test_loader))
imgs, labels = batch
grid = image_grid(imgs, rows=4, cols=4)
grid.show()
for idx, label in enumerate(labels):
print(idx, label.item())

這次吐出的 label 都是 -1 ,跟前面 Dataset 中設定的相同。

Conclusion

以上就是 Custom Dataset 的示範!這次的 transform 只有用 Resize 、 RandomCrop 以及 ToTensor 來轉換成符合 model 輸入的型態,實際上會用很多酷酷的方法把圖片做各種變化,可以讓同一張圖片變出不同樣式,也能讓 model 學習的更好,詳細的部份可以參考這邊

Reference

https://www.kaggle.com/competitions/dogs-vs-cats/data
https://stackoverflow.com/questions/37921295/python-pil-image-make-3x3-grid-from-sequence-images
https://www.tutorialspoint.com/how-to-convert-a-torch-tensor-to-pil-image
https://pytorch.org/vision/stable/transforms.html

--

--