Tutorial

Experiment Tracking Essentials: Finding the Right Tool

Gradio’s Custom Dashboards vs Wandb’s Built-In Tools for Training Diffusion Models

Anca Ioana Muscalagiu
Decoding ML


Photo by Stephen Dawson on Unsplash

Running experiments without tracking tools is like sailing without a compass — you may reach a destination, but is it the right one?

Experiment tracking is an essential component of any research workflow, especially in machine learning and data science. This is why it’s crucial to pay extra attention to how you choose and configure your tracking tools.

By the end of this article, you’ll learn how to:

  • Find the most suitable tool for your project
  • Set up training for diffusion models
  • Utilize Wandb’s built-in tools for tracking and analysis
  • Build custom dashboards with Gradio to visualize your experiments

But with so many tools out there, how do you choose the best one for your needs?

Finding the right tool can feel like navigating a maze, especially when you need a solution that captures every detail without overwhelming complexity. The real challenge is discovering a tool that strikes the perfect balance — one that’s intuitive to use but powerful enough to provide the deep insights needed to push your research forward.

To put this into practice, we’ll dive into a research scenario centered around training a diffusion model to generate images of butterflies. We’ll demonstrate how to track this process effectively using two Python notebooks available on Google Colab.

Table of Contents

  1. Basics of Diffusion Models
  2. Wandb’s Built-In Tools
  3. Custom Dashboards with Gradio
  4. Comparative Analysis
  5. Choosing the Right Tool: Pros, Cons, and Recommendations

1. Basics of Diffusion Models

Introduction to Diffusion

Diffusion models are probabilistic generative models that gradually transform data by adding noise and then learning to reverse this process to generate new samples. They simulate a diffusion process, where data is diffused into noise and then reconstructed, producing high-quality outputs like images, audio, or text.

Key Aspects of Diffusion Models

  1. Forward Process: The original data is gradually corrupted by adding small amounts of noise at each step, resulting in increasingly noisy versions until it becomes pure noise. The number of steps defines the length of the diffusion process.
  2. Reverse Process: This is where the key transformation takes place. A neural network is trained to reverse the noise addition, step by step. Starting with the highly noisy data from the final step, the network works backward, gradually removing the noise until it reconstructs something very close to the original, uncorrupted data.
  3. Neural Network Architecture: The neural network typically used in diffusion models is based on the U-Net architecture. U-Net is particularly effective because it captures both fine-grained local details and broader global features.

Overview of the Diffusion Process
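
To make the forward process concrete, below is a minimal sketch (not part of the accompanying notebooks) of how a noisy sample at timestep t can be produced in closed form, assuming the standard DDPM parameterization in which alphas_cumprod holds the cumulative product of (1 − beta_t):

import torch

def forward_diffuse(x0: torch.Tensor, alphas_cumprod: torch.Tensor, t: int) -> torch.Tensor:
    # closed-form forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
    noise = torch.randn_like(x0)      # epsilon ~ N(0, I), same shape as the input
    alpha_bar_t = alphas_cumprod[t]   # cumulative signal-retention factor at step t
    return alpha_bar_t.sqrt() * x0 + (1.0 - alpha_bar_t).sqrt() * noise

The DDPMScheduler used later in the training setup applies this same noising rule through its add_noise method, so we never have to implement it by hand.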

Training Setup for a Diffusion Model

First, we need to install the necessary libraries to set up our environment for training a diffusion model. We’ll primarily be using the diffusers package from Hugging Face [1], which is a popular library for working with diffusion models.

!pip install -q -U pyarrow==14.0.1 torch==2.3.1 torchaudio==2.3.1 gcsfs==2024.3.1 accelerate==0.33.0 torchvision==0.18.1 transformers==4.44.0 pytorch-fid==0.3.0 datasets==2.19.2 diffusers==0.30.0

To efficiently manage and organize the various settings required for training our diffusion model, we define a configuration class. The TrainingConfig class serves this purpose by encapsulating all the key parameters needed for the training process in one place:

from dataclasses import dataclass

@dataclass
class TrainingConfig:
    project_name: str = "diffusion-butterflies"  # the name of the project
    logger_name: str | None = None  # the logger used by accelerate
    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 16  # the batch size used during training
    eval_batch_size: int = 16  # how many images to sample during evaluation
    num_epochs: int = 50  # total number of training epochs
    dataset_name: str = "huggan/smithsonian_butterflies_subset"  # the name of the huggingface repository containing the dataset
    gradient_accumulation_steps: int = 1  # number of steps to accumulate gradients before performing an optimizer step
    learning_rate: float = 1e-4  # the initial learning rate for the optimizer
    lr_warmup_steps: int = 500  # the number of steps to warm up the learning rate from 0 to the initial learning rate
    save_image_epochs: int = 5  # the interval (in epochs) at which generated images are saved
    save_model_epochs: int = 5  # the interval (in epochs) at which the model is saved
    mixed_precision: str = "fp16"  # the precision type for training; "no" for float32, "fp16" for mixed precision
    output_dir: str = "ddpm-butterflies-128"  # the directory where model outputs (images, checkpoints) are saved
    overwrite_output_dir: bool = True  # overwrite the output directory if it already exists
    seed: int = 0  # the seed for random number generation to ensure reproducibility

config = TrainingConfig()

In the following code snippet, we use the butterfly image dataset, often considered the “Hello World” of diffusion models due to its simplicity. This dataset contains 1,000 samples, making it ideal for beginners.

We load the dataset and then visualize the first images to ensure everything is in the correct format:

from datasets import load_dataset
import matplotlib.pyplot as plt
import logging

logging.getLogger("datasets").setLevel(logging.ERROR)  # or logging.CRITICAL
dataset = load_dataset(config.dataset_name, split="train")
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for i, image in enumerate(dataset[:4]["image"]):
    axs[i].imshow(image)
    axs[i].set_axis_off()

print(dataset)
fig.show()

Image Samples from our Dataset

We proceed by setting up a preprocessing pipeline which ensures the images are properly formatted and ready for training the diffusion model.

from torchvision import transforms

preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

Next, we select the first 16 images from our dataset to serve as reference images. They act as ground truth when calculating the Fréchet Inception Distance (FID) score during the evaluation phase, in which the trained model generates a grid of images after learning the diffusion process.

import os
from PIL import Image
from diffusers.utils import make_image_grid
from typing import List

def save_grid_image(images: List[Image.Image], rows: int, cols: int, output_path: str) -> None:
    """
    Creates a grid of images, resizes the grid, and saves it to the specified output path.

    Args:
        images (List[Image.Image]): A list of PIL Image objects to be arranged in a grid.
        rows (int): The number of rows in the image grid.
        cols (int): The number of columns in the image grid.
        output_path (str): The file path where the resulting grid image will be saved.
    """
    grid = make_image_grid(images, rows=rows, cols=cols)
    new_width = grid.width // 2
    new_height = grid.height // 2
    grid = grid.resize((new_width, new_height), Image.LANCZOS)
    grid.save(output_path)

# save images individually for computing the FID
reference_images_dir = "reference_images"
os.makedirs(reference_images_dir, exist_ok=True)
reference_images_dataset = dataset.select(range(16))
dataset = dataset.select(range(16, len(dataset)))
images = []
for idx, example in enumerate(reference_images_dataset):
    image = example["image"].convert("RGB")
    image = preprocess(image)
    image = transforms.ToPILImage()(image)
    images.append(image)
    image.save(os.path.join(reference_images_dir, f"{idx}.png"))

# save images as grid for visualisation
reference_grid_path = os.path.join(reference_images_dir, "reference_image_grid.png")
save_grid_image(images, rows=4, cols=4, output_path=reference_grid_path)

We then apply our preprocessing pipeline to the rest of the dataset, which is used for training.

from typing import Dict, List

import torch
from PIL import Image

def transform(examples: Dict[str, List[Image.Image]]) -> Dict[str, List[torch.Tensor]]:
    """
    Applies preprocessing to a list of images by converting them to RGB, resizing,
    flipping, converting to tensor, and normalizing.

    Args:
        examples (Dict[str, List[Image.Image]]): A dictionary containing a list of images
            under the key "image".

    Returns:
        Dict[str, List[torch.Tensor]]: A dictionary containing the preprocessed image tensors
            under the key "images".
    """
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}

dataset.set_transform(transform)

The noise-adding process is a key step in simulating how an image is gradually corrupted during the forward diffusion process in a diffusion model.

In the following code block we define the add_noise function, which simulates the forward diffusion process. We apply the noise-adding function on the first image of the dataset to visualize its effects.

import torch
from PIL import Image
from diffusers import DDPMScheduler

def add_noise(images, noise_scheduler, timesteps):
    """
    Adds noise to the given images using the specified noise scheduler and timesteps.

    Args:
        images (torch.Tensor): The input images to which noise will be added.
        noise_scheduler: The noise scheduler that determines the noise level.
        timesteps (torch.Tensor): The timesteps at which noise is added.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing the noisy images and the noise added.
    """
    # generate random noise with the same shape as the images
    noise = torch.randn(images.shape, device=images.device, dtype=images.dtype)

    # add noise to the images using the noise scheduler
    noisy_images = noise_scheduler.add_noise(images, noise, timesteps)

    return noisy_images, noise

# select the first image from the dataset and add a batch dimension
sample_image = dataset[0]["images"].unsqueeze(0)

# initialize the noise scheduler with 1000 timesteps
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

# specify the timestep at which noise is added (e.g., at step 50)
timesteps = torch.LongTensor([50])

# add noise to the sample image using the noise scheduler
noisy_image, noise = add_noise(sample_image, noise_scheduler, timesteps)

# convert the noisy image back to a PIL Image for visualization
Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])

The next step involves setting up the key components required to train the diffusion model. This includes configuring the DataLoader, defining the UNet model, and setting up the optimizer and learning rate scheduler.

from diffusers import UNet2DModel
import torch
from diffusers.optimization import get_cosine_schedule_with_warmup

train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs),
)

Moreover, we define the evaluate function, which is designed to assess the diffusion model's performance during training by generating a batch of images through the model's image generation pipeline.

In order to evaluate the quality of the generated images, the function calculates the Fréchet Inception Distance (FID) score, which measures how closely the distribution of the generated images aligns with that of real, reference images.

A lower FID score indicates greater similarity between the generated and real images, reflecting better model performance.
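
Concretely, FID compares the mean and covariance of Inception-v3 feature embeddings computed over the two image sets: FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}), where (μ_r, Σ_r) describe the reference images and (μ_g, Σ_g) the generated ones.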

from diffusers import DDPMPipeline
from diffusers.utils import make_image_grid
from pytorch_fid import fid_score
from typing import Any
import os


def evaluate(config: Any, epoch: int, pipeline: Any, global_step: int) -> None:
    """
    Evaluates the model by generating images, saving them, and computing the FID score.

    Args:
        config (Any): The configuration object containing various parameters for training and evaluation.
        epoch (int): The current epoch number during training.
        pipeline (Any): The image generation pipeline (the diffusion model).
        global_step (int): The current step number in the training process.
    """
    # generate images using the model pipeline
    generated_images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.Generator(device="cpu").manual_seed(config.seed),
    ).images

    # save the images individually
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    generated_image_dir = os.path.join(test_dir, f"generated_epoch_{epoch}")
    os.makedirs(generated_image_dir, exist_ok=True)
    for idx, img in enumerate(generated_images):
        img.save(os.path.join(generated_image_dir, f"{idx}.png"))

    # save the image grid
    generated_grid_path = f"{test_dir}/{epoch:04d}.png"
    save_grid_image(generated_images, rows=4, cols=4, output_path=generated_grid_path)

    # compute the FID value using the reference images and the generated images
    fid_value = fid_score.calculate_fid_given_paths(
        [reference_images_dir, generated_image_dir],
        batch_size=config.eval_batch_size,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        dims=2048,
    )

    # store the epoch and the measured FID value
    track_metrics_per_epoch(fid_value, epoch)

NOTE: The track_metrics_per_epoch function is defined according to the experiment tracker used in each Python notebook.
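
If you want to run the training setup before committing to a tracker, one illustrative option (not part of the notebooks) is to start with no-op stubs and swap in the Wandb or Gradio versions from sections 2 and 3 afterwards:

# illustrative no-op placeholders; the actual implementations are defined in sections 2 and 3
def track_metrics_per_epoch(fid_value: float, epoch: int) -> None:
    pass

def track_metrics_per_step(metrics: dict, global_step: int) -> None:
    pass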

In the following section, we define our main training loop for the diffusion model using the Accelerate library to simplify distributed and mixed-precision training.

As we progress through each epoch, we add noise to our input images, predict the noise residual using the model, compute the loss, and update the model’s parameters. Throughout the training process, we actively track and log key metrics such as loss and learning rate, while also saving model checkpoints and evaluating performance at specified intervals.

from accelerate import Accelerator
from huggingface_hub import create_repo, upload_folder
from tqdm.auto import tqdm
from pathlib import Path
import os
import torch.nn.functional as F
from typing import Any

def train_loop(
    config: Any,
    model: torch.nn.Module,
    noise_scheduler: Any,
    optimizer: torch.optim.Optimizer,
    train_dataloader: torch.utils.data.DataLoader,
    lr_scheduler: Any,
) -> None:
    """
    The main training loop for the diffusion model.

    Args:
        config (Any): Configuration object containing various parameters for training.
        model (torch.nn.Module): The model to be trained.
        noise_scheduler (Any): Scheduler for adding noise during the forward diffusion process.
        optimizer (torch.optim.Optimizer): The optimizer used for training the model.
        train_dataloader (torch.utils.data.DataLoader): The dataloader providing training batches.
        lr_scheduler (Any): Learning rate scheduler to adjust the learning rate during training.
    """
    # initialize the accelerator for distributed or mixed-precision training
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with=config.logger_name,
        project_dir=os.path.join(config.output_dir, "logs"),
    )

    # set up tracking if this is the main process
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers(config.project_name)

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    # training loop over epochs
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        # iterate over batches in the dataloader
        for step, batch in enumerate(train_dataloader):
            clean_images = batch["images"]
            bs = clean_images.shape[0]

            # sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
                dtype=torch.int64
            )

            # add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images, noise = add_noise(clean_images, noise_scheduler, timesteps)

            with accelerator.accumulate(model):
                # predict the noise residual from the noisy images
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                # gradient clipping and optimization step
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()

                # reset gradients for the next step
                optimizer.zero_grad()

            # update progress bar and log the current loss and learning rate
            progress_bar.update(1)
            current_lr = lr_scheduler.get_last_lr()[0]
            metrics = {"loss": loss.detach().item(), "lr": current_lr, "step": global_step}
            track_metrics_per_step(metrics, global_step)
            progress_bar.set_postfix(**metrics)
            global_step += 1

        if accelerator.is_main_process:
            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

            # evaluate and save images
            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, pipeline, global_step)

            # save the model checkpoint
            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                pipeline.save_pretrained(config.output_dir)

NOTE: The track_metrics_per_step function is defined according to the experiment tracker used in each Python notebook.

Finally, we configure the training loop with the necessary components and launch it using notebook_launcher to start training the diffusion model.

from accelerate import notebook_launcher

# set up the arguments for the training loop function and launch it using the notebook_launcher
args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
notebook_launcher(train_loop, args, num_processes=1)

Now that the training setup is complete, we can move on to using Wandb and Gradio as experiment trackers. These tools allow us to monitor and visualize the training process by tracking key metrics such as the loss and the FID score.

2. Wandb’s Built-In Tools

Weights & Biases (Wandb) [2] is a powerful tool designed to track, visualize, and manage machine learning experiments with ease.

Whether you’re training a simple model or managing a complex pipeline, Wandb offers a suite of built-in tools that help you monitor your experiments in real time, compare different runs, and collaborate effectively with your team.

In addition to Wandb, there are several other excellent options available, such as Comet ML, MLFlow, and Neptune, each offering unique features for experiment tracking.

Key Features of Wandb:

  1. Live Tracking of Metrics: Monitor training metrics in real time.
  2. Seamless Integration: Compatible with popular machine learning frameworks like TensorFlow and PyTorch.
  3. Intuitive Dashboard: Provides clear and interactive visualizations of training progress and results.
  4. Run Comparisons and Hyperparameter Tuning: Easily compare different experiment runs and optimize hyperparameters to enhance model performance.

Using Wandb’s Built-In Dashboard

We start by installing the required Python library:

!pip install -q -U wandb==0.17.7

Afterwards, we set the necessary environment variables and configure the logger used by Accelerate so that Wandb serves as our experiment tracking tool.

import os

os.environ["WANDB_API_KEY"] = "<PLACE_YOUR_API_KEY>"
os.environ["WANDB_INIT_TIMEOUT"] ="300"
config.logger_name = "wandb"

The final step involves defining two methods that are used in the evaluation and training loops to log key metrics to Wandb:

import wandb

def track_metrics_per_epoch(fid_value, epoch):
    # log the FID value per epoch to wandb
    logs = {"fid": fid_value, "epoch": epoch}
    wandb.log(logs)

def track_metrics_per_step(metrics, global_step):
    # log the loss and lr value per step to wandb
    wandb.log(metrics, step=global_step)

When we are working with a custom metric that isn’t part of the standard metrics, we need to explicitly define it in Wandb. In this case, we define the FID score as a custom metric. Since FID isn’t logged at every step (which is the default in Wandb), we also specify that its step_metric should be based on the epoch rather than individual training steps.

# define a custom x-axis based on epochs for plotting the FID score
wandb.define_metric("epoch")
wandb.define_metric("fid", step_metric="epoch")

NOTE: The custom metrics in Wandb must be defined after Accelerate initializes the trackers, as the project must already be initialized in order to register them.
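
For example, one possible placement (a sketch based on the train_loop shown earlier, assuming config.logger_name has been set to "wandb") is immediately after the trackers are initialized:

# excerpt from train_loop: register the custom metrics once the trackers exist
if accelerator.is_main_process:
    accelerator.init_trackers(config.project_name)
    wandb.define_metric("epoch")                     # custom x-axis
    wandb.define_metric("fid", step_metric="epoch")  # plot the FID score against epochs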

The Wandb Dashboard can be visualized online under the “Runs” section of the specified project name on the Weights & Biases website:

Wandb’s Built-In Dashboard for Tracking the Training of the Diffusion Model

This is an example of how a run appears in the Wandb dashboard. A useful feature of Wandb is that it stores all the runs of a certain project, allowing you to revisit past experiments easily.

Additionally, you can overlap the plots from multiple runs to compare different experiments side by side. This feature makes it easier to analyze how different model configurations or training strategies impact the metrics over time.

3. Custom Dashboards with Gradio

Gradio [3] is a user-friendly tool designed to make it easy to create and share interactive interfaces for machine learning models.

Whether you’re building a quick prototype or showcasing your model to others, Gradio lets you create custom web interfaces with minimal effort. It’s ideal for researchers and developers who want to provide real-time, interactive demos of their models.

Alongside Gradio, other popular alternatives include Streamlit and Dash, each offering different features and flexibility for developing and sharing machine learning interfaces and applications.

Key Features of Gradio:

  1. No Frontend Code Interface Building: Create interactive model interfaces without needing any frontend coding knowledge or writing complex web code.
  2. Instant Deployment: Launch your model interface with just a few clicks and share it instantly.
  3. Flexible Input/Output Options: Support for a wide range of inputs like images, text, audio, and video, with outputs tailored to match.
  4. Interactive Testing: Enable users to test models with their own data directly in the interface, providing a hands-on experience.

Setting up a Custom Dashboard in Gradio

First, we install the required Python library:

!pip install -q -U gradio==4.41.0 

Next, we refine the tracking methods for the two types of metrics, appending the values to the lists that Gradio uses for its dashboard. Gradio continuously monitors these lists for updates and refreshes the plots every few seconds to reflect the latest data.

def track_metrics_per_epoch(fid_value, epoch):
    # append the FID value and epoch number to the respective lists for tracking
    fid_values.append(fid_value)
    epochs.append(epoch)

def track_metrics_per_step(metrics, global_step):
    # append the lr, loss value and step number to the respective lists for tracking
    lr_values.append(metrics["lr"])
    loss_values.append(metrics["loss"])
    steps.append(metrics["step"])

In the following code cell, we set up functions to monitor and visualize key metrics using Gradio. The get_latest_png function finds the most recently generated image grid from the evaluation phase, while get_lr_plot, get_loss_plot, and get_fid_plot create line plots for learning rate, loss, and FID score, respectively. These plots help us track the progress of our model during training. The get_image_paths function retrieves the latest generated and reference image grids for visual comparison.

from accelerate import notebook_launcher
import pandas as pd
import gradio as gr
import numpy as np
import re
from typing import Tuple


def get_latest_png() -> str | None:
    """
    Finds and returns the latest PNG grid of images generated by the model in the evaluation phase,
    based on the maximum epoch number found in the samples directory.
    The grid files follow a naming convention that encodes the epoch number, such as '0010.png' for epoch 10.

    Returns:
        str | None: The full path to the PNG file with the highest epoch number, or None if the directory does not exist or no PNG files are found.
    """
    # regular expression to match file names ending in an epoch number, like '0010.png'
    pattern = re.compile(r"(\d+)\.png$")
    latest_epoch = -1
    latest_file = None
    test_dir = os.path.join(config.output_dir, "samples")

    # check if the directory exists; if not, return None
    if not os.path.exists(test_dir):
        return None

    # iterate over all files in the directory
    for file_name in os.listdir(test_dir):
        match = pattern.search(file_name)
        if match:
            epoch = int(match.group(1))  # extract the epoch number from the file name
            if epoch > latest_epoch:
                latest_epoch = epoch
                latest_file = file_name

    if latest_file:
        return os.path.join(test_dir, latest_file)
    else:
        return None

def get_lr_plot() -> gr.LinePlot:
    """
    Generates a line plot of the learning rate over training steps.

    Returns:
        gr.LinePlot: A Gradio LinePlot component displaying learning rate over steps.
    """
    data = pd.DataFrame({"Steps": steps, "Learning Rate": lr_values})
    plot = gr.LinePlot(
        value=data,
        x="Steps",
        y="Learning Rate",
        title="Learning Rate per Step",
        width=600,
        height=350,
    )
    return plot

def get_loss_plot() -> gr.LinePlot:
    """
    Generates a line plot of the loss over training steps.

    Returns:
        gr.LinePlot: A Gradio LinePlot component displaying loss over steps.
    """
    data = pd.DataFrame({"Steps": steps, "Loss": loss_values})
    plot = gr.LinePlot(
        value=data,
        x="Steps",
        y="Loss",
        title="Loss per Step",
        width=600,
        height=350,
    )
    return plot

def get_fid_plot() -> gr.LinePlot:
    """
    Generates a line plot of the FID score over epochs.

    Returns:
        gr.LinePlot: A Gradio LinePlot component displaying FID score over epochs.
    """
    data = pd.DataFrame({"Epochs": epochs, "FID": fid_values})
    plot = gr.LinePlot(
        value=data,
        x="Epochs",
        y="FID",
        title="FID per Epoch",
        width=600,
        height=350,
    )
    return plot

def get_image_paths() -> Tuple[str, str]:
    """
    Retrieves the paths for the generated and reference image grids.

    Returns:
        Tuple[str, str]: A tuple containing the paths to the generated and reference image grids.
    """
    return get_latest_png(), reference_grid_path

Finally, we create a Gradio interface to display these metrics visually. The interface is organized into rows for line plots of learning rate, loss, and FID score, as well as for displaying reference and generated image grids.

The plots for the learning rate and the loss are updated every second, while the FID score and the generated images are refreshed only every few epochs, since they change less frequently. The Gradio interface is then launched, providing a real-time dashboard for monitoring the training process.

# initialize the metrics
lr_values = []
loss_values = []
fid_values = []
steps = []
epochs = []

# create the Gradio interface for displaying plots and images
with gr.Blocks() as demo:
    # row for learning rate plot
    with gr.Row():
        lr_plot = gr.LinePlot()

    # row for loss plot
    with gr.Row():
        loss_plot = gr.LinePlot()

    # row for FID plot
    with gr.Row():
        fid_plot = gr.LinePlot()

    # row for image grid captions
    with gr.Row():
        gr.Markdown("Reference Images Grid")
        gr.Markdown("Generated Images Grid")

    # row for displaying the reference and generated image grids
    with gr.Row():
        reference_image_display = gr.Image(height=400)
        generated_image_display = gr.Image(height=400)

    # load and update the plots and images periodically
    demo.load(get_lr_plot, None, lr_plot, every=1)
    demo.load(get_loss_plot, None, loss_plot, every=1)

    # the FID score and the generated images are updated less frequently, as they are computed only every few epochs
    demo.load(get_fid_plot, None, fid_plot, every=5)
    demo.load(get_image_paths, None, [generated_image_display, reference_image_display], every=5)

# launch the Gradio interface
demo.launch()

The Gradio dashboard is rendered directly in the Python Notebook:

Gradio Custom Dashboard for Tracking the Training of the Diffusion Model

The Gradio dashboard gives us everything we need to monitor our experiment’s progress effectively. Additionally, you can share it publicly on Hugging Face Spaces, allowing for easy collaboration and showcasing of your work.
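
If you only need a quick public link rather than a full Spaces deployment, Gradio can also tunnel the locally running dashboard; a minimal alternative to the plain demo.launch() call above is:

# creates a temporary public URL for the locally running dashboard
demo.launch(share=True)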

4. Comparative Analysis

In this section, we’ll compare Wandb and Gradio based on four important factors: efficiency, flexibility, user experience, and collaboration.

Efficiency (Time and Effort)

Wandb is highly efficient for tracking and logging experiments, offering quick setup, automated logging, and real-time updates that save significant time during training.

It’s particularly effective for monitoring experiments and persisting data across runs, as Wandb automatically saves and organizes your results, making it ideal for in-depth tracking over longer periods.

Gradio excels in rapidly building and deploying interactive model interfaces with minimal coding, making it ideal for quick model showcases and fast experiment overviews. While no frontend knowledge is required, you do need to write your Gradio interface in Python to set up your dashboard, which differs from Wandb’s out-of-the-box integration.

Additionally, if you need to persist data or track progress over time, you’ll need to implement your own persistence solution, as Gradio does not include built-in persistence features like Wandb.
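
A lightweight way to add such persistence yourself, sketched here under the assumption that a local CSV file is sufficient (the notebooks do not implement this), is to extend the step-tracking helper to also append each record to disk:

import csv

def track_metrics_per_step(metrics, global_step):
    # keep feeding the in-memory lists used by the Gradio plots
    lr_values.append(metrics["lr"])
    loss_values.append(metrics["loss"])
    steps.append(metrics["step"])
    # additionally persist every logged step so the run survives a notebook restart
    with open("metrics_log.csv", "a", newline="") as f:
        csv.writer(f).writerow([global_step, metrics["loss"], metrics["lr"]])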

In terms of efficiency, Wandb is best for long-term, detailed experiment tracking, while Gradio is better suited for quickly creating and sharing interfaces for immediate insights.

Flexibility (Interface Customization Features)

Gradio is the clear choice when it comes to interface customization and visuals. It excels in enabling you to design and tailor model interfaces to meet your specific needs.

You can easily adjust layouts, inputs, and overall appearance, making it simple to incorporate features like live updates of images used in the latest FID computation directly into your dashboard.

On the other hand, Wandb focuses more on customizing how you log and visualize metrics. It allows you to tailor dashboards, define custom metrics, and create detailed plots, but it doesn’t offer the same flexibility in interface design as Gradio.

While Wandb provides strong tools for managing and analyzing experiments, it’s less suited for the custom visualizations and interface adjustments that Gradio handles so effectively.

User Experience (Learning Curve)

Both tools are user-friendly but cater to different user needs.

Wandb is easier to grasp for basic tasks, especially if you’re already familiar with machine learning frameworks. Its straightforward setup and clear documentation make it easy to get started, though mastering its full range of features can take time, particularly for more advanced tracking and analysis.

Gradio, on the other hand, is intuitive and straightforward when it comes to creating interactive demos. You can quickly build interfaces without needing web development experience, making it ideal for sharing models with a wide audience. However, Gradio’s dashboard update mechanism can be a bit confusing for beginners, particularly when trying to set up live updates and custom visualizations.

Collaboration (Sharing within a Research Group)

Both Wandb and Gradio are strong in collaboration, but in different ways.

Wandb is built for team collaboration, making it easy to share experiment results, dashboards, and run comparisons within a group. It’s ideal for managing multiple runs and experiments in a team environment.

Gradio focuses on sharing interactive demos through simple links, making it easy to get feedback from both technical and non-technical stakeholders.

Wandb is best for collaborative experiment management, while Gradio is great for sharing and demonstrating models interactively.

5. Choosing the Right Tool: Pros, Cons, and Recommendations

Wandb is best suited for those who need detailed experiment tracking, automated logging, and comprehensive team collaboration features. It’s perfect for research projects that require careful monitoring of multiple metrics and the ability to compare different runs easily. However, it may have a slightly steeper learning curve, especially for users new to machine learning frameworks.

Gradio is ideal for quickly creating and sharing interactive model interfaces. It’s perfect for projects where the main goal is to demonstrate the model to others, gather feedback, or share results with non-technical stakeholders. While Gradio excels in ease of use and simplicity, it’s less focused on in-depth experiment tracking and analysis.

Recommendations:

  • Use Wandb if your priority is detailed experiment tracking, comparison, and collaboration within a research team.
  • Use Gradio if your focus is on creating and sharing interactive demos of your experiments, especially for showcasing the models to a wider audience such as other researchers, potential collaborators, or key stakeholders.
  • In many cases, using both tools together can provide an all-encompassing solution, combining the strengths of Wandb’s tracking and Gradio’s interactive presentation.

Conclusion

This article showcased the use of Gradio and Wandb as experiment trackers in the context of training diffusion models.

We walked through the process of training a diffusion model, covering the basics of how these models function.

We explored how to utilize Wandb’s built-in features and how to build a custom Gradio dashboard for tracking your training experiments.

Additionally, we compared Wandb and Gradio across four key criteria, highlighting each tool's strengths and weaknesses.

🔗 Check out the code on GitHub and support us with a ⭐️

Enjoyed This Article?

Join the Decoding ML Newsletter for battle-tested content on designing, coding, and deploying production-grade ML & MLOps systems. Every week. For FREE

References

[1] Hugging Face Diffusers Library

[2] Weights & Biases

[3] Gradio

Images

If not otherwise stated, all images are created by the author.


Anca Ioana Muscalagiu
Decoding ML

Software Engineer | Quantum Machine Learning Researcher