How to read TFRecord files in PyTorch!

Soumo Chatterjee
Published in Analytics Vidhya
Jul 31, 2020

Step 1 → First, you need to know what your data contains. As an example, I am going to use the Kaggle data for classifying 104 flower classes. It has 4 folders, each with 3 sub-folders for training, validation and test images. For this walkthrough I will use only the training and validation sub-folders, which hold the .tfrec files. NOTE: According to the dataset description, each .tfrec file in these sub-folders contains the id, the label (the class of the sample, for training and validation data) and img (the actual image as a byte string). Inside the files these features are stored under the keys id, class and image, which are the names the parsing code below uses.

Sample picture of our dataset

Step 2 → We will use the glob module to collect the paths of the .tfrec files in the training and validation sub-folders.

import glob

train_files = glob.glob('/kaggle/input/tpu-getting-started/*/train/*.tfrec')
val_files = glob.glob('/kaggle/input/tpu-getting-started/*/val/*.tfrec')

Let’s see what has been loaded into the train_files variable:

print(train_files[:5])
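According to the Python documentation, whether glob.glob returns sorted results depends on the file system, and the */train/ pattern matches the train sub-folder of every top-level folder (all four of them here). A small helper, hypothetical and not from the original post, that sorts the paths for a reproducible order and can optionally restrict the search to a single folder:

```python
import glob
import os


def list_tfrec_files(root, split, folder="*"):
    """Return the .tfrec paths under root/<folder>/<split>/, sorted.

    folder defaults to "*", which matches every top-level folder, just like
    the glob patterns above; pass one folder name to search only that folder.
    """
    pattern = os.path.join(root, folder, split, "*.tfrec")
    # Sort so the file order does not depend on the file system
    return sorted(glob.glob(pattern))
```

For example, list_tfrec_files('/kaggle/input/tpu-getting-started', 'train') returns the same files as the train_files pattern, in sorted order.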

Step 3 → Now we will collect the ids, class labels and image bytes into three separate lists, for both the training and the validation files.

# importing TensorFlow to read .tfrec files
import tensorflow as tf

We import TensorFlow first; here it is only needed to read the .tfrec files.

# Create a dictionary describing the features.
train_feature_description = {
    'class': tf.io.FixedLenFeature([], tf.int64),
    'id': tf.io.FixedLenFeature([], tf.string),
    'image': tf.io.FixedLenFeature([], tf.string),
}

This dictionary describes the three features: class (an int64 label), id (a string) and image (the image bytes, stored as a string).

def _parse_image_function(example_proto):
    # Parse one serialized tf.Example into a dict of tensors
    return tf.io.parse_single_example(example_proto, train_feature_description)

Next, we define a function that parses each serialized tf.Example proto using the train_feature_description dictionary.
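tf.data.TFRecordDataset is what splits each .tfrec file into the serialized records that _parse_image_function receives. The file framing itself is simple: each record is an 8-byte little-endian length, a 4-byte masked CRC-32C of that length, the record bytes, and a 4-byte masked CRC-32C of those bytes. The sketch below reads that framing in pure Python; read_tfrecord and write_tfrecord are hypothetical helpers, not part of TensorFlow or of the original post. It only splits a file into records: each record is still a serialized tf.train.Example that needs parsing, which is what tf.io.parse_single_example does above.

```python
import struct


def _make_crc32c_table():
    # Lookup table for CRC-32C (Castagnoli), reflected polynomial 0x82F63B78
    table = []
    for byte in range(256):
        crc = byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for b in data:
        crc = _CRC32C_TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc32c(data: bytes) -> int:
    # TFRecord stores a rotated-and-offset ("masked") CRC, not the raw CRC
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def read_tfrecord(path, verify=True):
    """Yield the raw bytes of each record (a serialized tf.train.Example)."""
    with open(path, "rb") as f:
        while True:
            header = f.read(12)  # uint64 length + uint32 masked CRC of length
            if not header:
                return  # clean end of file
            if len(header) != 12:
                raise ValueError("truncated record header")
            length, length_crc = struct.unpack("<QI", header)
            if verify and masked_crc32c(header[:8]) != length_crc:
                raise ValueError("corrupt record length")
            data = f.read(length)
            footer = f.read(4)  # uint32 masked CRC of the data
            if len(data) != length or len(footer) != 4:
                raise ValueError("truncated record")
            if verify and masked_crc32c(data) != struct.unpack("<I", footer)[0]:
                raise ValueError("corrupt record data")
            yield data


def write_tfrecord(path, records):
    """Write raw byte strings with the same framing (handy for testing)."""
    with open(path, "wb") as f:
        for data in records:
            length = struct.pack("<Q", len(data))
            f.write(length + struct.pack("<I", masked_crc32c(length)))
            f.write(data + struct.pack("<I", masked_crc32c(data)))
```

The CRC is computed byte by byte in Python, so verify=False is noticeably faster on large image files; TFRecordDataset remains the practical choice when TensorFlow is available.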

train_ids = []
train_class = []
train_images = []
for path in train_files:
    train_image_dataset = tf.data.TFRecordDataset(path)
    train_image_dataset = train_image_dataset.map(_parse_image_function)
    # One pass over each file, collecting all three features per record
    for features in train_image_dataset:
        # [2:-1] removes the leading b' and the trailing ' from the id's str() form
        train_ids.append(str(features['id'].numpy())[2:-1])
        train_class.append(int(features['class'].numpy()))
        train_images.append(features['image'].numpy())

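Finally, the features are stored in three separate lists. You can also build a DataFrame from these lists if that is more convenient. NOTE: the [2:-1] slice removes the leading b' and the trailing ' that str() adds around the byte-string ids; calling .decode('utf-8') on the bytes returned by .numpy() is a more direct alternative. We can do the same for the validation .tfrec files.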
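To see why .decode('utf-8') is the safer way to turn the id bytes into a string, compare the two approaches on an ASCII id and on a non-ASCII one. The function names are hypothetical, for illustration only. For plain ASCII ids both give the same result; for anything else the slicing approach returns escape sequences instead of the text.

```python
def id_by_slicing(raw: bytes) -> str:
    # What the loop does: str() gives "b'...'", and [2:-1] drops b' and '
    return str(raw)[2:-1]


def id_by_decoding(raw: bytes) -> str:
    # Direct alternative: decode the bytes as UTF-8 text
    return raw.decode("utf-8")


print(id_by_slicing(b"a1b2c3"), id_by_decoding(b"a1b2c3"))  # a1b2c3 a1b2c3

raw = "café".encode("utf-8")
print(id_by_slicing(raw))   # caf\xc3\xa9  (escape sequences, not the text)
print(id_by_decoding(raw))  # café
```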

To check that the image bytes were read correctly, we can display one of them in the notebook:

import IPython.display as display

display.display(display.Image(data=train_images[211]))
Output of the code above
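Displaying one image checks a single sample. A quick non-visual check over the whole list is sketched below, assuming the image bytes are JPEG-encoded: every JPEG stream starts with the start-of-image marker FF D8, so any entry that does not start with it was probably read incorrectly. non_jpeg_indices is a hypothetical helper.

```python
JPEG_SOI = b"\xff\xd8"  # JPEG start-of-image marker


def non_jpeg_indices(images):
    """Return the indices of entries that do not start with the JPEG SOI marker."""
    return [i for i, data in enumerate(images) if not data.startswith(JPEG_SOI)]
```

For example, len(non_jpeg_indices(train_images)) should be 0 if every record holds a JPEG.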

Step 4 → Finally, we create our PyTorch dataset class from the extracted features, i.e. train_ids, train_class and train_images.

First, some imports:

from PIL import Image
import cv2
import albumentations
import torch
import numpy as np
import io
from torch.utils.data import Dataset

We use the albumentations library for image transformations and then define our dataset class, FlowerDataset.
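A minimal sketch of what such a FlowerDataset class could look like, under stated assumptions (this is a reconstruction, not the author's code): it decodes each image with PIL, applies an optional albumentations-style transform (called as transform(image=array) and assumed to return a dict whose "image" entry is an HWC NumPy array), and returns a channels-first float32 tensor with the label as a long tensor. The ids are kept for reference but not returned.

```python
import io

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class FlowerDataset(Dataset):
    """Map-style dataset over the lists built from the .tfrec files."""

    def __init__(self, ids, labels, images, transform=None):
        if not (len(ids) == len(labels) == len(images)):
            raise ValueError("ids, labels and images must have the same length")
        self.ids = ids            # kept for reference, not returned
        self.labels = labels      # int class labels
        self.images = images      # encoded image bytes from the 'image' feature
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Decode the stored bytes into an RGB uint8 array of shape (H, W, 3)
        image = np.array(Image.open(io.BytesIO(self.images[idx])).convert("RGB"))
        if self.transform is not None:
            # albumentations-style call; assumes it returns an HWC NumPy array
            image = self.transform(image=image)["image"].astype(np.float32)
        else:
            # Without a transform, just scale pixel values to [0, 1]
            image = image.astype(np.float32) / 255.0
        # PyTorch models expect channels first: (3, H, W)
        image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label
```

A dry run could then look like this, assuming albumentations is installed; val_ids, val_class and val_images are hypothetical names for the validation lists built the same way as the training ones. A fixed-size Resize keeps every image the same shape, which batching requires:

```python
aug = albumentations.Compose([albumentations.Resize(224, 224), albumentations.Normalize()])
train_dataset = FlowerDataset(train_ids, train_class, train_images, transform=aug)
val_dataset = FlowerDataset(val_ids, val_class, val_images, transform=aug)
image, label = train_dataset[0]
print(image.shape, label)
```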

For a dry run, we create FlowerDataset objects, train_dataset and val_dataset, from the training and validation lists, then fetch one sample and check the image shape and label.

We can now pass the train_dataset and val_dataset FlowerDataset objects directly to PyTorch DataLoaders.
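A sketch of the DataLoader step, using a hypothetical stand-in dataset (StandInFlowers) that returns (image, label) pairs of the kind a FlowerDataset item would be: a float32 (3, H, W) tensor and a long label. In practice you would pass train_dataset and val_dataset instead. The training loader shuffles; the validation loader keeps a fixed order.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class StandInFlowers(Dataset):
    """Hypothetical stand-in returning random (image, label) pairs."""

    def __init__(self, n=10, size=32, num_classes=104):
        self.n, self.size, self.num_classes = n, size, num_classes

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        image = torch.rand(3, self.size, self.size)
        label = torch.tensor(idx % self.num_classes, dtype=torch.long)
        return image, label


train_loader = DataLoader(StandInFlowers(), batch_size=4, shuffle=True)
val_loader = DataLoader(StandInFlowers(), batch_size=4, shuffle=False)

# The default collate function stacks the items into batch tensors
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4])
```

With 10 items and batch_size=4 the loader yields batches of 4, 4 and 2; pass drop_last=True if every batch must be full.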

This is how we can read .tfrec files in PyTorch. If you have any questions, comments or concerns, let me know in the comments. Thanks for reading, and enjoy learning!
