How FASHABLE achieves SoA realistic AI-generated images using PyTorch and Azure Machine Learning

Orlando Ribas Fernandes
Published in PyTorch
8 min read · Feb 9, 2023
These products and models don’t exist: all were generated by Fashable AI.

Authors: Abubakar Zakari and Orlando Ribas Fernandes

Fashable is a company born at XNFY Lab (a joint initiative with Microsoft). Its main goal is to revolutionize the world of fashion with ethical Artificial Intelligence (AI) technologies built on the PyTorch framework. Fashable focuses on developing AI models that generate synthetic content for the global fashion industry. In recent years, the fashion industry has been criticized for generating a lot of waste and for being responsible for up to 10% of global carbon dioxide output. Fashable has stepped up to address this issue with multiple AI solutions that generate realistic, personalized consumer garments without physically producing them. This helps fashion brands make informed decisions without investing in experimental products, while also reducing the industry’s carbon footprint globally. To tackle these problems, Fashable’s proprietary (IP) models use modern approaches such as Generative Adversarial Networks (GANs), best-seller analysis, and custom dataset creation.

Furthermore, Fashable’s key goal is to build a powerful AI toolkit to support the fashion industry, although the toolkit can also be used in the metaverse, marketplaces, and social media. Fashion brands are in constant need of new content, and the competition for it has never been so intense; AI can generate very realistic images to meet that demand. Fashable’s proprietary AI fashion-generation technology can help brands reduce the time spent on research, design, and content creation, while reducing the risk of copyright infringement and freeing designers to focus on high-value tasks.

Abubakar with an AI-generated hoodie

Fashable’s AI technologies are developed with the PyTorch framework, which provides debuggability, ease of use, and performance for building deep learning models. Here, we discuss how we used PyTorch and Azure Machine Learning to build our SoA AI models, and we highlight the performance gains and cost savings we obtained with Distributed Data Parallel (DDP).

HOW FASHABLE UTILIZES MACHINE LEARNING TO POWER CONTENT GENERATION

Fashable develops Deep Learning (DL) models that generate synthetic content for the fashion industry, and our research and development (R&D) work focuses on making the generated garments look as realistic as possible, with the aim of reducing the industry’s carbon footprint around the globe. Fashable uses PyTorch across the full stack, from model development and training to inference. Our collaboration with PyTorch and Azure Machine Learning has enabled us to further optimize our workflow and make it more cost-effective, by speeding up model training and making better use of GPUs, which are expensive compute resources. PyTorch’s debuggability and ease of use have helped us speed up model development, and its distributed features let us scale our training jobs efficiently on Azure Machine Learning.

TRAINING GANS ON GIGABYTES OF IMAGES

Fashable owns IP on multiple GAN models; here we discuss XGAN, one of our recent developments. For GANs, and image-based ML models in general, one of the main challenges is the size of the images in the dataset: an advanced GAN can work with images of around 1024 x 1024 pixels, and a dataset can contain hundreds of thousands of such images, adding up to gigabytes of data. Data loading is therefore one of the major challenges in training GANs. Resizing the images is usually not an option, because lower-resolution inputs can significantly reduce the clarity and detail of the GAN’s output. Since each full-resolution image occupies a lot of GPU memory, this forces us to work with smaller batch sizes.
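To make the batch-size constraint concrete, the sketch below estimates how much memory the input batch alone occupies, assuming 3-channel images stored as float32 tensors. The batch sizes are illustrative assumptions, and activations, gradients, and optimizer state add considerably more on top of this figure.

```python
def input_batch_bytes(batch_size: int, height: int = 1024, width: int = 1024,
                      channels: int = 3, bytes_per_value: int = 4) -> int:
    """Bytes needed to hold one batch of images as a dense tensor (float32 by default)."""
    return batch_size * channels * height * width * bytes_per_value


MIB = 1024 ** 2

for batch_size in (1, 8, 32):
    size_mib = input_batch_bytes(batch_size) / MIB
    print(f"batch_size={batch_size:>2}: {size_mib:.0f} MiB for the input images alone")
```

A single 1024 x 1024 image takes 12 MiB this way, versus 0.75 MiB at 256 x 256. Activation memory in a convolutional network grows roughly in proportion to the spatial size as well, which is why high-resolution training ends up with small per-GPU batches.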

Azure Machine Learning provides an easy-to-use API and performant object storage for storing training images and loading them into the distributed training infrastructure on its compute clusters.

To look more closely at GAN architecture, let’s start with the high-level architecture of a basic GAN model. As the image below shows, the model is composed of two key components: the Generator and the Discriminator. The Generator produces fake images, while the Discriminator’s job is to distinguish between fake and real images.

High-level architecture of a basic GAN model.
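As a minimal sketch of this Generator/Discriminator setup (a textbook GAN, not XGAN itself), the PyTorch code below defines two tiny fully connected networks on flattened 28 x 28 images and runs one adversarial training step with the standard non-saturating loss. The layer sizes, learning rate, and the random "real" batch are placeholder assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

LATENT_DIM, IMG_DIM = 64, 28 * 28  # placeholder sizes for illustration

# Generator: maps random noise to a fake (flattened) image in [-1, 1]
generator = nn.Sequential(
    nn.Linear(LATENT_DIM, 256), nn.ReLU(),
    nn.Linear(256, IMG_DIM), nn.Tanh(),
)
# Discriminator: outputs one logit per image (high = "real", low = "fake")
discriminator = nn.Sequential(
    nn.Linear(IMG_DIM, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
)
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))


def train_step(real: torch.Tensor):
    batch = real.size(0)
    ones, zeros = torch.ones(batch, 1), torch.zeros(batch, 1)

    # 1) Discriminator step: push real images toward 1 and generated images toward 0
    fake = generator(torch.randn(batch, LATENT_DIM))
    d_loss = (F.binary_cross_entropy_with_logits(discriminator(real), ones)
              + F.binary_cross_entropy_with_logits(discriminator(fake.detach()), zeros))
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # 2) Generator step: try to make the (updated) Discriminator label fakes as real
    g_loss = F.binary_cross_entropy_with_logits(discriminator(fake), ones)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()


real_batch = torch.rand(16, IMG_DIM) * 2 - 1  # stand-in for real images scaled to [-1, 1]
print(train_step(real_batch))
```

The Discriminator sees `fake.detach()` so that its update does not push gradients into the Generator; the Generator step then reuses the same fake images without detaching them.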

In general, Fashable’s GAN model (XGAN) combines several image-synthesis techniques that are computationally fast and efficient and have been shown to help generate high-quality output images: path length regularization, R1 regularization, a wavelet technique, and data augmentation. The data augmentation in our model applies differentiable transformations to the Discriminator’s inputs, both real and fake images. These transformations encourage the Discriminator to learn meaningful details of the images and allow GANs to produce high-quality images from small datasets.
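Two of these ideas are easy to show in isolation. The sketch below is an illustrative PyTorch version, not Fashable’s XGAN code: `r1_penalty` computes the standard R1 regularizer, (gamma / 2) * E[||grad_x D(x)||^2] over real images, and `diff_augment` applies random brightness and contrast changes built purely from tensor arithmetic, so gradients still flow through it back to the Generator (the general idea behind differentiable augmentation). The helper names, gamma = 10, the choice of transforms, and the toy Discriminator are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def diff_augment(x: torch.Tensor) -> torch.Tensor:
    """Random brightness and contrast jitter made only of differentiable tensor ops."""
    n = x.size(0)
    # Brightness: add a per-image offset in [-0.5, 0.5)
    x = x + (torch.rand(n, 1, 1, 1, device=x.device) - 0.5)
    # Contrast: scale each image's deviation from its mean by a factor in [0.5, 1.5)
    mean = x.mean(dim=(1, 2, 3), keepdim=True)
    return (x - mean) * (torch.rand(n, 1, 1, 1, device=x.device) + 0.5) + mean


def r1_penalty(discriminator: nn.Module, real: torch.Tensor, gamma: float = 10.0) -> torch.Tensor:
    """R1 regularization: (gamma / 2) * E[||grad_x D(x)||^2] over real images."""
    real = real.detach().requires_grad_(True)
    logits = discriminator(real)
    # create_graph=True so the penalty itself can be backpropagated into D's weights
    (grad,) = torch.autograd.grad(outputs=logits.sum(), inputs=real, create_graph=True)
    return 0.5 * gamma * grad.pow(2).flatten(1).sum(dim=1).mean()


# Usage with a toy Discriminator on 32 x 32 RGB images
disc = nn.Sequential(nn.Conv2d(3, 8, 4, 2, 1), nn.LeakyReLU(0.2),
                     nn.Flatten(), nn.Linear(8 * 16 * 16, 1))
real = torch.rand(4, 3, 32, 32) * 2 - 1
fake = (torch.rand(4, 3, 32, 32) * 2 - 1).requires_grad_(True)  # stand-in for Generator output

# The same kind of augmentation is applied to real and fake images before D sees them
d_real, d_fake = disc(diff_augment(real)), disc(diff_augment(fake))
loss_d = F.softplus(-d_real).mean() + F.softplus(d_fake).mean() + r1_penalty(disc, real)
loss_d.backward()
```

After `backward()`, `fake.grad` is populated, which is the point of keeping the augmentation differentiable: the Generator can still learn through the augmented images.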

Being able to train on small datasets is particularly important for us, because creating a large in-house custom dataset is challenging. Next, we discuss the distributed training setup that makes training much more efficient and cost-effective.

SPEEDING UP TRAINING BY MOVING FROM DP TO DDP

Reducing training time has a great impact on cost savings and is also aligned with our mission to reduce the carbon footprint of the fashion industry. Initially, we used PyTorch’s vanilla DataParallel (DP) for our training jobs, but PyTorch recommends using Distributed Data Parallel (DDP) instead.

Moving from DP to DDP (the PyTorch documentation includes a comparison of the two) paid off greatly: it cut our training time by roughly 6.3x.

For our XGAN model, one training epoch took 35 minutes on average with DP, compared with 5.58 minutes with DDP.
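The speedup follows directly from the two epoch times above. The 100-epoch run in the sketch is a hypothetical, used only to illustrate how the per-epoch saving accumulates over a full training job.

```python
dp_epoch_min = 35.0    # average epoch time with DataParallel
ddp_epoch_min = 5.58   # average epoch time with DistributedDataParallel

speedup = dp_epoch_min / ddp_epoch_min
print(f"Speedup: {speedup:.2f}x")  # -> Speedup: 6.27x

# Hypothetical 100-epoch run, to show what the per-epoch difference adds up to
epochs = 100
saved_hours = epochs * (dp_epoch_min - ddp_epoch_min) / 60
print(f"Time saved over {epochs} epochs: {saved_hours:.1f} hours")
```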

DISTRIBUTED DATA PARALLEL TRAINING

To scale our training jobs, we used DDP. DDP parallelizes model training over multiple devices: each device holds a full replica of the model, processes its own shard of the data, and synchronizes gradients with the other devices after each backward pass.
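DDP can train on shards and still behave like single-device training because, with a mean-reduced loss and equally sized shards, averaging the per-replica gradients (which is what DDP’s all-reduce step achieves) gives exactly the full-batch gradient. The single-process PyTorch sketch below simulates this with two model replicas; the linear model, random data, and replica count are arbitrary assumptions.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = nn.MSELoss()  # mean reduction over the batch

# Reference: gradient of the loss over the full batch on a single device
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Simulated DDP: each replica holds a full copy of the model and sees half of the batch
world_size = 2
replica_grads = []
for x_shard, y_shard in zip(x.chunk(world_size), y.chunk(world_size)):
    replica = copy.deepcopy(model)
    replica.zero_grad()
    loss_fn(replica(x_shard), y_shard).backward()
    replica_grads.append(replica.weight.grad)

# All-reduce (sum) followed by division by the world size = averaging the gradients
averaged_grad = torch.stack(replica_grads).mean(dim=0)
print(torch.allclose(full_grad, averaged_grad, atol=1e-6))  # True
```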

DDP uses multiprocessing, whereas DP uses multithreading, so DP is usually slower than DDP even on a single machine because of Global Interpreter Lock (GIL) contention across threads, the per-iteration model replication, and the extra overhead of scattering inputs and gathering outputs. DDP can also be combined with model parallelism for cases where the model is too large to fit on a single GPU. Our training stack is built on top of Azure Machine Learning and PyTorch. This lets us leverage high-performance compute clusters with A100 GPUs, efficient object storage, and Azure Container for PyTorch, along with full support for the latest PyTorch distributed training features. Azure Machine Learning also helps Fashable store multiple trained models and their various versions.

The basic process of DDP training in our model is outlined as follows.

Step 1: Distributed training initialization

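```python
import os
from datetime import timedelta

import torch.distributed as dist

if self.distributed:
    # Set by the launcher (e.g. torchrun or Azure Machine Learning's PyTorch distribution)
    self.world_size = int(os.environ["WORLD_SIZE"])
    self.rank = int(os.environ["RANK"])
    self.local_rank = int(os.environ["LOCAL_RANK"])

    print("init_process_group", flush=True)

    dist.init_process_group(backend="nccl", init_method="env://",
                            world_size=self.world_size, rank=self.rank,
                            timeout=timedelta(hours=2))
```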

With the code above, the distributed environment is initialized only when the distributed hyperparameter is set to True. In that case, each process reads its world_size, rank, and local_rank from the environment variables that the launcher sets for the participating devices and nodes, and then joins the NCCL process group. For this, we use Python 3.8 and PyTorch 1.11 in a Docker-based container.
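For local debugging without a cluster, the same initialization can be exercised in a single CPU process. The sketch below is a variant of Step 1 under stated assumptions, not our production code: `setup_distributed` is a hypothetical helper that falls back to WORLD_SIZE=1 and RANK=0 when the launcher’s variables are absent, fills in MASTER_ADDR/MASTER_PORT (which `init_method="env://"` also requires), and switches to the gloo backend when no GPU is available, because NCCL requires CUDA.

```python
import os
import socket
from datetime import timedelta

import torch
import torch.distributed as dist


def _free_port() -> int:
    # Ask the OS for an unused TCP port for the rendezvous
    with socket.socket() as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def setup_distributed():
    """Hypothetical helper: read launcher variables, defaulting to a 1-process run."""
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # init_method="env://" also needs MASTER_ADDR and MASTER_PORT
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", str(_free_port()))

    backend = "nccl" if torch.cuda.is_available() else "gloo"  # NCCL needs GPUs
    if backend == "nccl":
        torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=backend, init_method="env://",
                            world_size=world_size, rank=rank,
                            timeout=timedelta(hours=2))
    return world_size, rank, local_rank


world_size, rank, local_rank = setup_distributed()
print(world_size, rank, dist.get_backend())
dist.destroy_process_group()
```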

Step 2: Model wrapping using Distributed Data Parallel

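```python
import os

import torch
import torch.nn as nn

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)  # make .cuda() default to this process's GPU

XGAN = self.gan.cuda()

# Replace BatchNorm layers with SyncBatchNorm so statistics are shared across processes
XGAN = nn.SyncBatchNorm.convert_sync_batchnorm(XGAN)

# Gradients are all-reduced across processes during backward()
XGAN = nn.parallel.DistributedDataParallel(XGAN, device_ids=[local_rank])
# No extra .to(device) call is needed: the model is already on cuda:local_rank
```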

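With this, the model (XGAN) first has its BatchNorm layers converted to SyncBatchNorm, which computes batch statistics across all processes, and is then wrapped in nn.parallel.DistributedDataParallel with local_rank as the device_ids. As long as each process sets its current CUDA device to its local rank (with torch.cuda.set_device(local_rank)), calls such as tensor.cuda() or model.cuda() place data on that process’s own GPU.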
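The wrapping step can also be tried on a CPU-only machine. In the sketch below (an illustration, not XGAN), a small placeholder network joins a one-process gloo group and is wrapped in DistributedDataParallel without device_ids, which is how DDP is used for CPU modules. SyncBatchNorm is left out on purpose, since it is intended for GPU training.

```python
import socket

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process "world" on CPU, purely for illustration
with socket.socket() as s:
    s.bind(("127.0.0.1", 0))
    port = s.getsockname()[1]
dist.init_process_group(backend="gloo", init_method=f"tcp://127.0.0.1:{port}",
                        world_size=1, rank=0)

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))  # placeholder network
ddp_model = DDP(model)  # CPU module: leave device_ids unset
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)

out = ddp_model(torch.randn(4, 16))
out.pow(2).mean().backward()  # gradients are all-reduced across ranks during backward
optimizer.step()

print(type(ddp_model.module).__name__, out.shape)  # the original model lives in .module
dist.destroy_process_group()
```

When checkpointing, saving `ddp_model.module.state_dict()` rather than `ddp_model.state_dict()` avoids the `module.` prefix on every parameter name, so the weights can later be loaded into an unwrapped model.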

Step 3: Utilizing DistributedSampler in our DataLoader

As the code snippet below shows, in a distributed environment the dataset must be wrapped in a DistributedSampler, configured with the number of replicas (the world size) and the current rank, before it is loaded by the DataLoader. The resulting sampler (sampler in our code) is then passed as a parameter to the DataLoader, so that each process loads a distinct subset of the data across all the nodes and GPUs used for training.

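```python
from datetime import timedelta

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import transforms

# Dataset and SquarePad are project-specific classes not shown in this post


def init_dataset(self, data_path, target_img_size, pad_instead_of_resize, batch_size):
    if pad_instead_of_resize:
        # Pad to a square first so the aspect ratio is preserved, then resize
        data_transforms = transforms.Compose([
            SquarePad(),
            transforms.ToTensor(),
            transforms.Resize(target_img_size),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # scale to [-1, 1]
        ])
    else:
        # Resize the shorter side, then crop the center to a square
        data_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(target_img_size),
            transforms.CenterCrop(target_img_size),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # scale to [-1, 1]
        ])

    self.dataset = Dataset(path=data_path, transform=data_transforms)

    sampler = None
    if self.distributed:
        # The process group is normally created in Step 1; initializing it twice raises an error
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl", init_method="env://",
                                    world_size=self.world_size, rank=self.rank,
                                    timeout=timedelta(hours=2))
        # Give each rank a disjoint shard of the dataset
        sampler = DistributedSampler(self.dataset, num_replicas=self.world_size, rank=self.rank)

    self.dataloader = DataLoader(self.dataset,
                                 batch_size=batch_size,
                                 num_workers=2,
                                 shuffle=(sampler is None),  # the sampler shuffles when distributed
                                 sampler=sampler,
                                 drop_last=True,
                                 pin_memory=True)
```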
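To see what the sampler does, the snippet below builds a DistributedSampler for each of two hypothetical ranks over a toy 10-item dataset. Passing num_replicas and rank explicitly means no process group is needed. Each rank receives a disjoint half of the indices, and because the sampler shuffles by default, `set_epoch(epoch)` should be called at the start of every epoch; otherwise every epoch reuses the same ordering.

```python
from torch.utils.data import DistributedSampler

dataset = list(range(10))  # toy stand-in for the image dataset
world_size = 2


def shards_for_epoch(epoch: int):
    """Indices each rank would load in the given epoch."""
    shards = []
    for rank in range(world_size):
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                     shuffle=True, seed=0)
        sampler.set_epoch(epoch)  # reseeds the shuffle so each epoch gets a new order
        shards.append(list(sampler))
    return shards


epoch0 = shards_for_epoch(0)
print(epoch0)
```

In a real training loop this amounts to calling `sampler.set_epoch(epoch)` once per epoch, before iterating over the DataLoader.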

CONCLUSION

For evaluation, GANs need their own specific metrics to measure performance. For XGAN, Fashable uses the Fréchet Inception Distance (FID) metric. FID measures the similarity between the training dataset and the generated images using the features of a hidden layer of an Inception-v3 network trained on ImageNet. A low FID value correlates with high image quality and high diversity.
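For reference, FID fits a Gaussian to each set of features and measures the distance between the two: FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). The NumPy sketch below implements that formula on feature matrices that are assumed to be already extracted (for example, Inception-v3 pool features). It is not the exact evaluation code used for XGAN, and the random features in the usage example are placeholders.

```python
import numpy as np


def _sqrtm_psd(mat: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * np.sqrt(np.clip(vals, 0, None))) @ vecs.T


def frechet_distance(feats_real: np.ndarray, feats_gen: np.ndarray) -> float:
    """FID between two (num_samples, feature_dim) arrays of network features."""
    mu_r, mu_g = feats_real.mean(axis=0), feats_gen.mean(axis=0)
    sigma_r = np.cov(feats_real, rowvar=False)
    sigma_g = np.cov(feats_gen, rowvar=False)

    # Tr((S_r S_g)^(1/2)) equals Tr((S_r^(1/2) S_g S_r^(1/2))^(1/2)), whose inner matrix is PSD
    root_r = _sqrtm_psd(sigma_r)
    covmean_trace = np.trace(_sqrtm_psd(root_r @ sigma_g @ root_r))

    mean_term = np.sum((mu_r - mu_g) ** 2)
    return float(mean_term + np.trace(sigma_r) + np.trace(sigma_g) - 2 * covmean_trace)


rng = np.random.default_rng(0)
real = rng.normal(size=(500, 8))
print(frechet_distance(real, real))        # ~0: identical statistics
print(frechet_distance(real, real + 1.0))  # ~8: each of the 8 feature means shifted by 1
```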

Fashable uses FID to assess the quality of the images generated by our model; the lower the score, the better the model performance. As the image below shows, XGAN consistently achieves an FID score of 2.63 on our custom dataset, lower than the state-of-the-art models shown in the figure. For example, Google’s Imagen, regarded as a high-performing model, achieves an FID score of 7.27. It is important for readers to understand that Imagen’s score is reported on COCO, which is a well-structured and organized dataset compared with Fashable’s custom dataset. Yet XGAN achieves a lower FID.

Results from Google’s Imagen — https://imagen.research.google/
Fashable XGAN — FID results

In order to attain such achievements and meet our future goals, Fashable needs to rely on an ML development framework that:

  • Facilitates quick iteration and easy extension,
  • Supports large-model training and inference on very large image data, and
  • Provides ease of use and robustness.

As we’ve demonstrated, PyTorch offers us all these capabilities and more. We are incredibly excited about the future of PyTorch and cannot wait to see what other impactful challenges we can solve using the framework on Azure Machine Learning.

Microsoft Customer Story article — “Fashable reimagines the future of fashion design with Azure Machine Learning and PyTorch”

Microsoft Customer Story article link:

Contact us for collaborations, or if you want to create your own fashion collection: partnerships@fashable.ai

For more information about Fashable: https://www.fashable.ai/

Follow us: LinkedIn, Facebook, Instagram and Twitter.

Acknowledgment

We would like to thank the PyTorch Team, including Geeta Chauhan and Hamid Shojanazeri for their assistance and helpful suggestions over the last few months. We would also like to thank Alan Weaver from Microsoft AI Black Belts team for his constant support.
