How to read TFRecord (.tfrec) files in PyTorch
Step 1 → First of all, you need to know what your data contains. As an example, I am going to use the Kaggle data for classifying 104 flower classes. There are 4 folders, each with 3 sub-folders for training, validation and testing images. I will use only the training and validation sub-folders, which contain the .tfrec files, for this walkthrough. NOTE: According to the dataset description, each .tfrec file in a sub-folder contains the id, the label (the class of the sample, for training & validation data) and the img (the actual pixels in byte string format). The parsing code in Step 3 reads these features under the keys id, class and image.
Step 2 → We will use the glob library to collect the paths of the files in the training and validation sub-folders.
import glob

train_files = glob.glob('/kaggle/input/tpu-getting-started/*/train/*.tfrec')
val_files = glob.glob('/kaggle/input/tpu-getting-started/*/val/*.tfrec')
Let’s see what was loaded into the train_files variable:
print(train_files[:5])
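If you want to see how the two wildcards behave without the Kaggle data at hand, here is a small sketch that builds a throwaway folder tree and runs the same kind of pattern on it. The folder names (size-a, size-b) and the file name 00-part.tfrec are made up for illustration. Note that glob returns paths in no guaranteed order, so I sort them here.

```python
import glob
import os
import tempfile

# Build a temporary tree that mimics the layout described in Step 1.
# All folder and file names below are invented for this demo.
root = tempfile.mkdtemp()
for size in ("size-a", "size-b"):
    for split in ("train", "val", "test"):
        os.makedirs(os.path.join(root, size, split))
        open(os.path.join(root, size, split, "00-part.tfrec"), "w").close()

# The first * matches every top-level folder, the second every .tfrec file.
demo_train = sorted(glob.glob(os.path.join(root, "*", "train", "*.tfrec")))
demo_val = sorted(glob.glob(os.path.join(root, "*", "val", "*.tfrec")))

print(len(demo_train), len(demo_val))
```

Each pattern picks up one file per top-level folder, and the test sub-folders are ignored.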
Step 3 → Now, we will collect the ids, classes and image bytes in three separate lists for the training & validation files.
# importing tensorflow to read .tfrec files
import tensorflow as tf
Importing tensorflow first
# Create a dictionary describing the features.
train_feature_description = {
    'class': tf.io.FixedLenFeature([], tf.int64),
    'id': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([], tf.string),
}
This dictionary describes the features stored in each record: the class, the id and the image as a byte string.
def _parse_image_function(example_proto):
    return tf.io.parse_single_example(example_proto, train_feature_description)
Then we create a function that parses each serialized tf.Example proto using the dictionary (train_feature_description).
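To see what the parser returns, here is a small self-contained sketch. It builds one record in memory and parses it with the same kind of feature description. The values (class 3, id b'abc123' and the two image bytes) are invented and are not taken from the Kaggle data.

```python
import tensorflow as tf

# Same keys and types as train_feature_description.
feature_description = {
    'class': tf.io.FixedLenFeature([], tf.int64),
    'id': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([], tf.string),
}

# One invented record, serialized the way a .tfrec file stores it.
example = tf.train.Example(features=tf.train.Features(feature={
    'class': tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),
    'id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'abc123'])),
    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[b'\x00\x01'])),
}))
serialized = example.SerializeToString()

# The result is a dict of scalar tensors, one per key.
parsed = tf.io.parse_single_example(serialized, feature_description)
print(parsed['class'].numpy(), parsed['id'].numpy(), parsed['image'].numpy())
```

Calling .numpy() on the string features gives plain Python bytes, which is why the ids need to be converted to str afterwards.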
train_ids = []
train_class = []
train_images = []

for i in train_files:
    train_image_dataset = tf.data.TFRecordDataset(i)
    train_image_dataset = train_image_dataset.map(_parse_image_function)

    # [2:-1] is done to remove b' from the start and ' from the end of the train id names
    ids = [str(id_features['id'].numpy())[2:-1] for id_features in train_image_dataset]
    train_ids = train_ids + ids

    classes = [int(class_features['class'].numpy()) for class_features in train_image_dataset]
    train_class = train_class + classes

    images = [image_features['image'].numpy() for image_features in train_image_dataset]
    train_images = train_images + images
Finally, we store the features in 3 different lists. You can also create a dataframe from these lists for convenience. NOTE: [2:-1] removes the leading b' and the trailing ' from the train id names. The same steps can be repeated for the validation .tfrec files.
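The slicing trick works because str() on a bytes object returns its printed form, quotes and b prefix included. As a hedged alternative, bytes.decode gives the same string for plain ASCII ids without relying on that printed form. The id value below is made up and stands in for what id_features['id'].numpy() returns.

```python
# Made-up id, standing in for id_features['id'].numpy()
raw_id = b"338ab7bac"

sliced = str(raw_id)[2:-1]        # "b'338ab7bac'" with the b' and ' cut off
decoded = raw_id.decode("utf-8")  # direct bytes -> str conversion

print(sliced, decoded)

# The two approaches differ once the bytes contain non-ASCII characters.
tricky = "caf\u00e9".encode("utf-8")
print(str(tricky)[2:-1])       # escape sequences leak into the string
print(tricky.decode("utf-8"))  # the original text
```

For ids that only contain ASCII letters and digits both approaches agree, so the slicing used above is fine for this dataset's ids if they are plain ASCII.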
To check that the images were read correctly, we can display one of them:
import IPython.display as display

display.display(display.Image(data=train_images[211]))
Step 4 → Finally, we will create our PyTorch dataset class from the extracted features, i.e. train_ids, train_class and train_images.
Making some imports first
from PIL import Image
import cv2
import albumentations
import torch
import numpy as np
import io
from torch.utils.data import Dataset
Here we use the albumentations library for transformations and then define our dataset class.
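For a dry run, we can instantiate the class and fetch a single sample to check its output.
Next, we create FlowerDataset objects for the training and validation data (train_dataset and val_dataset).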
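The three steps above (defining the class, a dry run and creating the object) can be sketched as follows. This is a minimal version under my own assumptions, not necessarily the exact class used with this data: it decodes the bytes with PIL, scales pixels to [0, 1] and returns a dict with the keys id, image and label. The transform argument is assumed to be any callable that maps an (H, W, 3) array to an (H, W, 3) array. An albumentations pipeline is called as aug(image=img)['image'], so it would need a small wrapper to fit here. The red 8x8 test image and the id "abc123" are invented so the snippet runs on its own.

```python
import io

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class FlowerDataset(Dataset):
    """Map-style dataset over the id, class and image-bytes lists (a sketch)."""

    def __init__(self, ids, classes, images, transform=None):
        self.ids = ids
        self.classes = classes
        self.images = images        # encoded image bytes, one entry per sample
        self.transform = transform  # assumed: callable, HWC array -> HWC array

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        # Decode the encoded bytes into an RGB array of shape (H, W, 3).
        img = Image.open(io.BytesIO(self.images[idx])).convert("RGB")
        img = np.array(img)
        if self.transform is not None:
            img = self.transform(img)
        # Scale to [0, 1] and move channels first: (3, H, W).
        img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
        return {
            "id": self.ids[idx],
            "image": img,
            "label": torch.tensor(self.classes[idx], dtype=torch.long),
        }


# Dry run with one invented sample: an 8x8 red PNG encoded in memory.
buf = io.BytesIO()
Image.new("RGB", (8, 8), color=(255, 0, 0)).save(buf, format="PNG")

train_dataset = FlowerDataset(["abc123"], [7], [buf.getvalue()])
sample = train_dataset[0]
print(sample["id"], sample["image"].shape, sample["label"])
```

With the real data you would build the object as FlowerDataset(train_ids, train_class, train_images), and do the same with the corresponding validation lists for val_dataset.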
Now we can pass these train_dataset & val_dataset objects of FlowerDataset directly to PyTorch data loaders.
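As a hedged sketch of this last step, the snippet below wraps a dataset in torch.utils.data.DataLoader. TinyDataset is a hypothetical stand-in so that the snippet runs on its own. In practice you would pass train_dataset and val_dataset instead. The batch size of 4 and the 8x8 image size are arbitrary choices.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TinyDataset(Dataset):
    """Hypothetical stand-in that returns items shaped like flower samples."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return {
            "image": torch.zeros(3, 8, 8),
            "label": torch.tensor(idx % 104, dtype=torch.long),
        }


# Shuffle the training data each epoch, keep validation order fixed.
train_loader = DataLoader(TinyDataset(10), batch_size=4, shuffle=True)
val_loader = DataLoader(TinyDataset(6), batch_size=4, shuffle=False)

# The default collate function stacks each dict entry into a batch tensor.
batch = next(iter(train_loader))
print(batch["image"].shape, batch["label"].shape)
```

Each batch is a dict with the same keys as a single sample, so a training loop can read batch["image"] and batch["label"] directly.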
This is how we can read .tfrec files in PyTorch. Let me know in the comments if you have any questions, comments or concerns. Thanks for reading, and until then, enjoy learning.