Diffusion Transformer and Rectified Flow Transformer for Conditional Image Generation
Overview
- Introduction
- Diffusion Transformer (DiT)
- Rectified Flow Transformer
- Distillation for fast inference
- Summary
- References by topic
- Appendix
- Introduction
Over the past several years, the quality, aesthetics, and prompt adherence of image generation neural networks have improved rapidly. Understanding a few key technical innovations helps explain why recent models greatly outperform earlier diffusion models. In this article, we focus on the Diffusion Transformer (DiT) and explore key advances for high-quality class-conditioned and text-to-image generation.
- Diffusion Transformer (DiT)
DiT, introduced by William Peebles and Saining Xie in 2023, marked a significant departure from earlier diffusion models (Peebles and Xie, 2023). Earlier diffusion models for image generation relied on two key design features: a convolutional U-Net backbone (Ho et al., 2020) and, more recently, a latent diffusion architecture (Rombach et al., 2022). DiT retains the latent diffusion framework but replaces the U-Net with a vision transformer (ViT) backbone that operates on latent patches (Figure 1). DiT also embeds the timestep and class label and injects this conditioning information into the model through adaptive layer normalization (adaLN). The model learns to remove noise from the latent representation, conditioned on the timestep and class label embeddings.
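To make the adaLN conditioning concrete, below is a minimal sketch of a transformer block that regresses its normalization parameters from the conditioning embedding. The class and variable names are illustrative and simplified relative to the official DiT implementation, which additionally zero-initializes the modulation layer (the adaLN-Zero variant).
import torch
import torch.nn as nn

class AdaLNBlockSketch(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        # LayerNorm without learnable affine parameters; scale and shift come from the condition
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # Regress shift, scale, and gate from the conditioning embedding (timestep + class label)
        self.to_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim))

    def forward(self, x, cond):
        # x: (batch, tokens, dim) latent patch tokens; cond: (batch, dim) conditioning embedding
        shift, scale, gate = self.to_mod(cond).chunk(3, dim=-1)
        h = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        h, _ = self.attn(h, h, h)
        return x + gate.unsqueeze(1) * h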
The forward diffusion process begins with a latent representation z = E(x) obtained from a frozen encoder E. At each timestep, Gaussian noise is incrementally added to z. The model is trained to reverse this process: it predicts the noise at each timestep, conditioned on the class label c. During training, the predicted mean corresponds to the noise prediction and is optimized with a mean squared error (MSE) loss, while the predicted diagonal covariance is optimized with a KL divergence term. To enable classifier-free guidance, the class label c is randomly dropped during training and replaced with a learned “null” embedding ∅. Image generation then proceeds via the reverse diffusion process: the latent is initialized with random Gaussian noise and iteratively denoised. At step t, the model takes as input the current noisy latent, the timestep, and the class label for conditioning. Finally, the denoised latent representation is decoded into an image with the frozen decoder D, x = D(z).
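As a rough sketch of how classifier-free guidance is applied at sampling time (the official DiT code implements this inside model.forward_with_cfg, which we call below), the class-conditional and null-class noise predictions are combined with a guidance scale s, corresponding to cfg_scale in the sampling code later in this article:
def cfg_noise_prediction(eps_cond, eps_null, s=4.0):
    # Push the prediction away from the unconditional ("null" class) estimate
    # toward the class-conditional one; s = 1 recovers the unguided conditional prediction
    return eps_null + s * (eps_cond - eps_null)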
In terms of model scaling, Peebles and Xie experimented with configurations ranging from 33 million to 675 million parameters. The largest model, DiT-XL/2, trained on ImageNet achieved a state-of-the-art 2.27 FID on the class-conditional 256 × 256 generation benchmark.
Below is the code to run the official PyTorch DiT model for 256 × 256 generation. I am using an NVIDIA L4 GPU on Google Colaboratory. First, set up the environment and import the dependencies.
!git clone https://github.com/facebookresearch/DiT.git
import DiT, os
os.chdir('DiT')
os.environ['PYTHONPATH'] = '/env/python:/content/DiT'
!pip install diffusers timm --upgrade
# DiT imports:
import torch
from torchvision.utils import save_image
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from download import find_model
from models import DiT_XL_2
from PIL import Image
from IPython.display import display
torch.set_grad_enabled(False)
device = "cuda"
Now load the model.
image_size = 256
vae_model = "stabilityai/sd-vae-ft-ema"
latent_size = int(image_size) // 8  # the VAE downsamples images by a factor of 8
# Load model:
model = DiT_XL_2(input_size=latent_size).to(device)
state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
model.load_state_dict(state_dict)
model.eval() # important!
vae = AutoencoderKL.from_pretrained(vae_model).to(device)
Sample from DiT using class labels from ImageNet 1K. I chose four dog classes to compare the results. The class list can be found here.
seed = 0
torch.manual_seed(seed)
num_sampling_steps = 100
cfg_scale = 4
# ImageNet class labels (four dog classes)
class_labels = 162, 247, 248, 254
samples_per_row = 2
# Create diffusion object:
diffusion = create_diffusion(str(num_sampling_steps))
# Create sampling noise:
n = len(class_labels)
z = torch.randn(n, 4, latent_size, latent_size, device=device)
y = torch.tensor(class_labels, device=device)
# Setup classifier-free guidance:
z = torch.cat([z, z], 0)
y_null = torch.tensor([1000] * n, device=device)
y = torch.cat([y, y_null], 0)
model_kwargs = dict(y=y, cfg_scale=cfg_scale)
# Sample images:
samples = diffusion.p_sample_loop(
    model.forward_with_cfg, z.shape, z, clip_denoised=False,
    model_kwargs=model_kwargs, progress=True, device=device
)
samples, _ = samples.chunk(2, dim=0) # Remove null class samples
samples = vae.decode(samples / 0.18215).sample
# Save and display images:
save_image(samples, "sample.png", nrow=int(samples_per_row),
           normalize=True, value_range=(-1, 1))
samples = Image.open("sample.png")
display(samples)
The four class-conditioned sampling results are shown in Figure 2 below.
- Rectified Flow Transformer
Stability AI built upon the DiT architecture with the release of Stable Diffusion 3 (SD3) for text-to-image generation. This release introduced several key updates, including an improved training objective, the Multimodal Diffusion Transformer (MMDiT) architecture, and a fine-tuning procedure incorporating human feedback. These modifications are briefly summarized below.
SD3 employs rectified flow (Esser et al., 2024) in its training objective, leveraging the concept of optimal transport (OT) to connect the data distribution and the noise distribution along straight paths (Liu et al., 2022; Albergo & Vanden-Eijnden, 2023; Lipman et al., 2023). In simple terms, unlike the diffusion process, which follows a stochastic trajectory, OT establishes deterministic straight-line paths (Figure 3). In the context of generative modeling, OT defines the mapping between noise and the data distribution with an ordinary differential equation (ODE). This approach is advantageous because the forward process directly shapes the learned reverse process, improving sampling efficiency. Rectified flow is a specific OT-based objective that defines the forward process as straight paths between the data distribution and a standard normal distribution. This formulation can be reparameterized into a noise prediction objective to align with diffusion training.
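As a minimal sketch of the idea, here is what a rectified-flow training step could look like with a velocity-prediction loss along the straight interpolation path. The timestep weighting and exact parameterization used in SD3 are simplified away, and the function and argument names are illustrative.
import torch
import torch.nn.functional as F

def rectified_flow_loss(model, x0, cond):
    noise = torch.randn_like(x0)                     # sample from the standard normal distribution
    t = torch.rand(x0.shape[0], device=x0.device)    # uniform timestep in [0, 1]
    t_ = t.view(-1, 1, 1, 1)
    z_t = (1 - t_) * x0 + t_ * noise                 # straight-line path between data and noise
    v_target = noise - x0                            # constant velocity along that path
    v_pred = model(z_t, t, cond)                     # the model predicts the velocity field
    return F.mse_loss(v_pred, v_target)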
The authors of SD3 found that a fixed text representation is not ideal for image generation. Instead, MMDiT uses separate learnable streams for image and text tokens and mixes the two modalities inside the attention operation, which the authors describe as enabling a two-way flow of information between image and text. MMDiT incorporates three text encoders, two based on CLIP and one based on T5, to represent the textual input. In each MMDiT block, query-key normalization is applied before the attention matrix is computed, which reduces instability from growing attention logits during training and further simplifies fine-tuning. The models range from 800 million to 8 billion parameters.
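Below is a minimal sketch of query-key normalization applied before the attention computation. I use a plain RMS normalization for illustration; the exact normalization, learnable scales, and head layout in MMDiT may differ.
import torch
import torch.nn.functional as F

def qk_norm_attention(q, k, v, eps=1e-6):
    # Normalize queries and keys along the head dimension so the attention logits stay bounded
    q = q * torch.rsqrt(q.pow(2).mean(dim=-1, keepdim=True) + eps)
    k = k * torch.rsqrt(k.pow(2).mean(dim=-1, keepdim=True) + eps)
    return F.scaled_dot_product_attention(q, k, v)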
SD3 was fine-tuned with Direct Preference Optimization (DPO), an alternative to Reinforcement Learning from Human Feedback introduced by Rafailov et al. for language models and adapted to diffusion models by Wallace et al. (Rafailov et al., 2023; Wallace et al., 2023). DPO has been shown to improve image quality, prompt adherence, and text rendering directly, without training a separate reward model as is common in reinforcement learning approaches. In SD3, DPO fine-tuning is combined with Low-Rank Adaptation (LoRA) matrices in the 2B and 8B parameter models. Results are shown in the Appendix.
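For intuition, here is a heavily simplified sketch of a Diffusion-DPO style loss: the fine-tuned model is rewarded for denoising the human-preferred image better than a frozen reference model does, relative to the rejected image. The function signature and the constant beta are illustrative, and the per-timestep weighting from Wallace et al. is folded into beta here.
import torch
import torch.nn.functional as F

def diffusion_dpo_loss(err_theta_w, err_ref_w, err_theta_l, err_ref_l, beta=5000.0):
    # err_* are per-sample denoising MSEs on the preferred (w) and rejected (l) images,
    # for the fine-tuned model (theta) and the frozen reference model (ref)
    margin = (err_theta_w - err_ref_w) - (err_theta_l - err_ref_l)
    return -F.logsigmoid(-beta * margin).mean()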
- Distillation for fast inference
Current distillation approaches for generative models aim to improve sampling speed while preserving the iterative refinement capability of diffusion models. Recent model releases have built further on the rectified flow transformer with distillation. Stable Diffusion 3.5 (SD3.5) introduced an 8.1-billion-parameter Large model and a distilled version called Large Turbo. The Turbo model leverages adversarial diffusion distillation (ADD) to enable efficient sampling in just 1–4 steps while maintaining high image quality (Sauer et al., 2023).
Below is the code for running inference with SD3.5 Large Turbo using Hugging Face Diffusers. The model is free but gated, so you will need to log in with a Hugging Face token.
# Gated model: Login with a HF token with gated access permission
!huggingface-cli login
Now load the model. It will require several minutes to download.
# Load SD3.5 Large Turbo in bfloat16 (a fixed random seed is set below)
import torch
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large-turbo",
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)
Place the model on GPU and run inference. I am using the A100 GPU from NVIDIA on Google Colaboratory.
pipe = pipe.to("cuda")
#provide a fixed starting seed with generator
generator = torch.Generator(device="cuda").manual_seed(0)
prompt = 'A highly detailed, beautiful human face with very high resolution captured in a captivating camera shot that is aesthetically pleasing, the individual is attractive and has a friendly approachable expression.'
# Use a guidance scale of zero for Turbo Large
image = pipe(
    prompt,
    num_inference_steps=4,
    guidance_scale=0.0,
    generator=generator,
).images[0]
image
SD3.5 Large Turbo achieves performance similar to SD3.5 Large in only a few sampling steps. The distilled model retains the iterative refinement ability while requiring fewer steps to generate high-quality results. The results are shown below in Figures 4–6.
ADD employs two loss functions during training: 1) an adversarial loss, which keeps outputs on the manifold of real images by training the model to fool a discriminator, and 2) a distillation loss, which uses a pre-trained diffusion model as a teacher and guides the student to match its denoised targets. Black Forest Labs has also contributed the open-weight FLUX.1 series. The FLUX.1 models are distilled directly from the closed-weight 12-billion-parameter FLUX.1 [pro] model: FLUX.1 [schnell] uses ADD, while FLUX.1 [dev] employs guidance distillation to improve sampling efficiency.
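To make the two ADD loss terms above concrete, here is a heavily simplified sketch of how they could be combined for the student generator. The discriminator formulation, distance metric, and weighting in Sauer et al. differ in detail, and the names and the weight lam are illustrative.
import torch.nn.functional as F

def add_student_loss(d_fake_logits, x_student, x_teacher, lam=2.5):
    # Adversarial term: the student tries to make the discriminator score its samples as real
    adv = F.softplus(-d_fake_logits).mean()
    # Distillation term: match the denoised target produced by the frozen teacher diffusion model
    distill = F.mse_loss(x_student, x_teacher)
    return adv + lam * distill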
Below is the code to run FLUX.1 [schnell]. This model is also gated and requires a Hugging Face token to access.
import torch
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
prompt = "An ancient stone tablet partially buried in a desert at dusk, covered in glowing symbols. The inscription reads: 'We will need transformers for artificial intelligence' in runic and futuristic text. The symbols shimmer faintly, casting shadows on the sand. Above, a translucent scroll unfurls with more mysterious runes, while the sky transitions from pink to purple with a surreal moon on the horizon."
image = pipe(
    prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
).images[0]
image
The results from ADD models including SD3.5 Large Turbo and FLUX.1 [schnell] are shown in Figure 7 below.
FLUX.1 [dev] applies guidance distillation to make classifier-free guidance more efficient at inference time. This two-stage process first trains a student model to match the classifier-free-guided outputs of a frozen teacher model, then progressively distills the student into a version requiring fewer sampling steps (Meng et al., 2023). The results of FLUX.1 [dev] are shown in Figure 7.
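A hedged sketch of the first stage is shown below: the student is conditioned on the guidance scale w and trained to match the teacher's guided prediction in a single forward pass. Function and argument names are illustrative, and guidance-scale conventions vary between papers.
import torch
import torch.nn.functional as F

def guidance_distillation_loss(student, teacher, z_t, t, cond, null_cond, w):
    with torch.no_grad():
        eps_c = teacher(z_t, t, cond)        # teacher's conditional prediction
        eps_u = teacher(z_t, t, null_cond)   # teacher's unconditional prediction
        target = eps_u + w.view(-1, 1, 1, 1) * (eps_c - eps_u)  # guided teacher output
    # The student takes the guidance scale as an extra input, so one pass replaces two teacher passes
    return F.mse_loss(student(z_t, t, cond, w), target)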
- Summary
Diffusion models with transformer backbones advanced the state of the art in class-conditional image generation. Rectified flow training further improves the quality and sampling efficiency of text-to-image generation. Distilled rectified flow transformers deliver highly aesthetic results within a few sampling steps.
- References by topic
Diffusion Transformer (DiT)
Peebles and Xie, Scalable diffusion models with transformers, 2023.
Denoising diffusion
Ho et al., Denoising Diffusion Probabilistic Models, 2020.
Rombach et al., High-Resolution Image Synthesis with Latent Diffusion Models, 2022.
Rectified Flow Transformer
Esser et al., Scaling Rectified Flow Transformers for High-Resolution Image Synthesis, 2024.
Rectified flow
Liu et al., Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow, 2022.
Albergo and Vanden-Eijnden, Building Normalizing Flows with Stochastic Interpolants, 2023.
Lipman et al., Flow Matching for Generative Modeling, 2023.
Direct Preference Optimization
Rafailov et al., Direct Preference Optimization: Your Language Model is Secretly a Reward Model, 2023.
Wallace et al., Diffusion Model Alignment Using Direct Preference Optimization, 2023.
Model distillation
Sauer et al., Adversarial Diffusion Distillation, 2023.
Meng et al., On Distillation of Guided Diffusion Models, 2023.
Recent model releases
Introducing Stable Diffusion 3.5 (accessed in Nov. 2024).
Announcing Black Forest Labs (accessed in Nov. 2024).
- Appendix
A. Visual comparison of recent and legacy models
The Hugging Face Diffusers library was used to load 16-bit models on an A100 GPU. Inference was run with 28 sampling steps and a guidance scale of 7. The prompts used to generate each of the image panels below are listed at the end of this section.
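For reference, here is a minimal sketch of how a panel was generated, assuming the same Diffusers pipeline API as in the Turbo example above; the model id is shown for SD3.5 Large, and the other models were loaded analogously with their own pipelines.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k, with the text 'The future is with diffusion transformers' integrated into the image."
image = pipe(prompt, num_inference_steps=28, guidance_scale=7.0).images[0]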
Select transformer results, arranged from smallest to largest model size.
Select results from legacy U-Net models, arranged from smallest to largest.
Prompt 1
“Astronaut in a jungle, cold color palette, muted colors, detailed, 8k, with the text ‘The future is with diffusion transformers’ integrated into the image.”
Prompt 2
“An ancient stone tablet partially buried in a desert at dusk, covered in glowing symbols. The inscription reads: ‘We will need transformers for artificial intelligence’ in runic and futuristic text. The symbols shimmer faintly, casting shadows on the sand. Above, a translucent scroll unfurls with more mysterious runes, while the sky transitions from pink to purple with a surreal moon on the horizon.”
B. Description of U-Net models
Stable Diffusion 2.0 pairs an 860-million-parameter U-Net with the OpenCLIP-ViT/H text encoder. The model uses a latent diffusion formulation with continuous-time diffusion. Starting with 2.0, the models were trained from scratch on a subset of LAION-5B filtered with the LAION-NSFW classifier. SDXL, described by Podell et al. in 2023, scales up the same framework with a larger cross-attention context, multi-aspect-ratio training, a larger 2.6-billion-parameter U-Net, and two text encoders (CLIP ViT-L and OpenCLIP ViT-bigG). SDXL 1.0 uses an ensemble-of-experts pipeline in which a base model handles the early denoising steps and a specialized refiner model completes the final steps.
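As an illustration of this base-plus-refiner split, below is a hedged sketch using the Diffusers SDXL checkpoints; the split fraction and other options are assumptions based on the typical Diffusers usage pattern rather than the exact settings used for the panels above.
import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline

base = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# The base model handles the early denoising steps and hands latents to the refiner
latents = base(prompt=prompt, num_inference_steps=28, denoising_end=0.8,
               output_type="latent").images
image = refiner(prompt=prompt, num_inference_steps=28, denoising_start=0.8,
                image=latents).images[0]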
Reference
Podell et al., SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis, 2023.