[PyTorch] 4. Custom Dataset Class, Convert image to tensor and vice versa

jun94
Published in jun-devpBlog
Apr 5, 2020

1. Custom Dataset Class

PyTorch provides two classes, torch.utils.data.Dataset and torch.utils.data.DataLoader, that make it easy to load a dataset and split it into mini-batches. The basic workflow is to first define a torch.utils.data.Dataset (or a custom Dataset, if needed) and then pass it as an argument to torch.utils.data.DataLoader, as shown in the last two lines of the code below.
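As a minimal sketch of that workflow, the snippet below uses torch.utils.data.TensorDataset, a ready-made Dataset that wraps existing tensors, and hands it to DataLoader to get mini-batches. The toy tensors and the batch size of 4 are assumptions chosen for illustration.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy data: 10 samples with 3 features each, plus one target per sample.
x = torch.arange(30, dtype=torch.float32).reshape(10, 3)
y = torch.arange(10, dtype=torch.float32).reshape(10, 1)

# TensorDataset is a ready-made Dataset that indexes the tensors along dim 0.
dataset = TensorDataset(x, y)

# DataLoader groups samples into mini-batches; shuffle=True reorders them every epoch.
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for batch_idx, (xb, yb) in enumerate(loader):
    # With 10 samples and batch_size=4, the batches hold 4, 4, and 2 samples.
    print(batch_idx, xb.shape, yb.shape)
```

Because 10 is not divisible by 4, the last batch is smaller; passing drop_last=True to DataLoader would discard it instead.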

Sometimes, however, we need our own Dataset class. Thanks to PyTorch, there is little to do in this case: we only need to implement a class that inherits from torch.utils.data.Dataset and defines three methods, __init__, __len__, and __getitem__.

Figure 1 shows the basic methods that must be implemented to build a custom dataset class.

Figure 1. Three functions inherited from torch.utils.data.Dataset

The code below is an example of how to build a custom dataset and load data from it.
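A minimal sketch of such a class, assuming a small in-memory toy dataset (the class name CustomDataset and the numbers are hypothetical, not the original code): __init__ stores the data, __len__ returns the number of samples, and __getitem__ returns one sample as tensors. DataLoader then calls these methods to assemble shuffled mini-batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self):
        # Build the data once; here it is a hypothetical in-memory toy set.
        self.x_data = [[73, 80, 75], [93, 88, 93], [89, 91, 90], [96, 98, 100], [73, 66, 70]]
        self.y_data = [[152], [185], [180], [196], [142]]

    def __len__(self):
        # Number of samples; DataLoader uses it to know which indices exist.
        return len(self.x_data)

    def __getitem__(self, idx):
        # Return a single (input, target) pair as float tensors.
        x = torch.tensor(self.x_data[idx], dtype=torch.float32)
        y = torch.tensor(self.y_data[idx], dtype=torch.float32)
        return x, y

dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

for epoch in range(2):
    for batch_idx, (x_batch, y_batch) in enumerate(dataloader):
        # The default collate function stacks the per-sample tensors into one batch.
        print(epoch, batch_idx, x_batch.shape, y_batch.shape)
```

For large datasets, a common design choice is to keep only file paths in __init__ and read each file lazily inside __getitem__, so the whole dataset never has to fit in memory.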

2. Load Image and Convert it to Tensor

When we feed one or more images to a model as input, they must first be converted to a torch.Tensor. Conversely, we sometimes need to convert a tensor back into an image to visualize it. This back-and-forth conversion is needed often, and it is easy to do with an image-handling package such as PIL together with PyTorch.

The code below is a sample of these conversions, followed by descriptions of the PyTorch classes I used.

Figure 2. Descriptions of the ToTensor and ToPILImage classes (from [1])

3. References

[1] https://pytorch.org/docs/stable/torchvision/transforms.html

[2] https://wikidocs.net/57165

Any corrections, suggestions, and comments are welcome.
