Hierarchical Image Pyramid Transformers

Ayyuce Demirbas
Feb 5, 2024 · 12 min read


This paper introduces the Hierarchical Image Pyramid Transformer (HIPT), a novel Vision Transformer (ViT) architecture designed for analyzing gigapixel whole-slide images (WSIs) in computational pathology. HIPT leverages the inherent hierarchical structure of WSIs to learn high-resolution image representations through self-supervised learning. Pretrained on a large dataset covering 33 cancer types and evaluated across multiple slide-level tasks, HIPT demonstrates superior performance in cancer subtyping and survival prediction, showcasing the potential of self-supervised learning models to capture crucial inductive biases and phenotypes within the tumor microenvironment. Read the full paper here: https://arxiv.org/pdf/2206.02647.pdf

Figure source: https://arxiv.org/pdf/2206.02647.pdf

This figure illustrates the hierarchical structure of whole-slide images (WSIs) used in computational pathology. On the left, it shows the multilevel approach where a large tissue image (150,000 x 150,000 pixels) is broken down into smaller, more manageable sections: first into 4096 x 4096 regions showing tissue phenotypes, then further into 256 x 256 cellular organization patches, and finally into the smallest 16 x 16 cellular features. On the right, it demonstrates how a 256 x 256 image is composed of a sequence of 256 smaller 16 x 16 tokens, and in turn, how each 256 x 256 image can be a part of a larger, disjoint sequence of 256 x 256 tokens within a 4096 x 4096 region. This hierarchical tokenization allows for the handling and analysis of very large images at different resolutions and scales.
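To make these scales concrete, here is a quick back-of-the-envelope calculation (mine, not from the paper's code) of how many tokens each level contributes:

# Token counts implied by the hierarchy: 4096 x 4096 regions, 256 x 256 patches, 16 x 16 cells.
region, patch, cell = 4096, 256, 16

patches_per_region = (region // patch) ** 2              # 256 patches of 256 x 256 per region
cells_per_patch = (patch // cell) ** 2                   # 256 tokens of 16 x 16 per patch
cells_per_region = patches_per_region * cells_per_patch  # 65,536 cell-level tokens per region

print(patches_per_region, cells_per_patch, cells_per_region)  # 256 256 65536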

The model consists of three stages of hierarchical aggregation: 16x16 visual tokens are aggregated bottom-up within their respective 256x256 and 4096x4096 windows, and the resulting region embeddings are finally pooled into the slide-level representation. The key components of HIPT are as follows:

1. Hierarchical Aggregation: HIPT aggregates visual tokens at the cell-, patch-, and region-level to form slide representations. This hierarchical approach is motivated by the use of hierarchical representations in natural language processing, where embeddings can be aggregated at different levels to form document representations. Similarly, in the context of WSI, the hierarchical aggregation allows the model to capture information at different levels of granularity, from individual cells to broader tissue structures.

2. Transformer Self-Attention: To model important dependencies between visual concepts at each stage of aggregation, HIPT adapts Transformer self-attention as a permutation-equivariant aggregation layer. This enables the model to capture complex relationships between visual tokens and learn representations that encode both local and global context within the images.

3. Pretraining and Self-Supervised Learning: HIPT is pretrained using self-supervised learning on a large dataset of gigapixel WSIs across 33 cancer types. The model leverages two levels of self-supervised learning to learn high-resolution image representations, and uses student-teacher knowledge distillation to pretrain each aggregation layer with self-supervised learning on regions as large as 4096x4096 (a minimal sketch of this student-teacher recipe follows this list).

4. Performance and Applications: The results of the study demonstrate that HIPT with hierarchical pretraining outperforms current state-of-the-art methods on slide-level tasks. The model’s performance is evaluated on 9 slide-level tasks, including cancer subtyping and survival prediction, and shows superior performance in capturing broader prognostic features in the tissue microenvironment.
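As a rough illustration of the student-teacher distillation mentioned in point 3, the snippet below sketches the general DINO-style recipe: the teacher is an exponential-moving-average copy of the student, and the student is trained to match the teacher's softened output distribution. This is a simplified sketch for intuition, not the authors' training loop; `student`, `teacher`, and the temperature and momentum values are assumptions.

# Minimal sketch (assumed, not the authors' training code) of DINO-style
# student-teacher knowledge distillation.
import torch
import torch.nn.functional as F

@torch.no_grad()
def update_teacher(student, teacher, momentum=0.996):
    # Teacher parameters track an exponential moving average of the student's.
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.mul_(momentum).add_(p_s.data, alpha=1 - momentum)

def dino_loss(student_out, teacher_out, student_temp=0.1, teacher_temp=0.04, center=0.0):
    # Cross-entropy between the sharpened, centered teacher distribution
    # and the student distribution over the projection-head outputs.
    t = F.softmax((teacher_out - center) / teacher_temp, dim=-1).detach()
    s = F.log_softmax(student_out / student_temp, dim=-1)
    return -(t * s).sum(dim=-1).mean()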

Figure source: https://arxiv.org/pdf/2206.02647.pdf

From left to right, the figure shows three levels of aggregation:

  1. Cell-Level Aggregation: Individual cells are represented by 16 x 16 px tokens, which are aggregated into a patch-level representation using a ViT256-16 model, followed by global pooling to obtain a single vector representation.
  2. Patch-Level Aggregation: The 256 x 256 px patches are processed using a larger ViT variant designed for 256 x 256 inputs, again followed by a pooling layer that summarizes the patch-level features into a region-level representation.
  3. Region-Level Aggregation: Finally, the 4096 x 4096 px regions are aggregated, this time using a ViT that takes the entire region as input, leading to a global attention pooling layer that provides the slide-level representation.

This hierarchical process allows the model to handle the immense scale of WSIs by breaking down the problem into manageable parts and focusing on different levels of details, from cellular to tissue structures.
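Putting the three stages together, the forward pass can be summarized with the pseudocode-like sketch below. It is a conceptual outline under assumed shapes (384-dim ViT-256 tokens, 192-dim ViT-4K tokens), not the repository's exact pipeline, which is shown further down; `vit16_256`, `vit256_4k`, and `attention_pool` stand in for the three pretrained stages.

import torch

def hipt_slide_embedding(regions_4k, vit16_256, vit256_4k, attention_pool):
    """Conceptual sketch: aggregate a list of [3 x 4096 x 4096] regions into one slide vector."""
    region_tokens = []
    for region in regions_4k:                                              # each region: [3 x 4096 x 4096]
        patches = region.unfold(1, 256, 256).unfold(2, 256, 256)           # [3 x 16 x 16 x 256 x 256]
        patches = patches.permute(1, 2, 0, 3, 4).reshape(-1, 3, 256, 256)  # [256 x 3 x 256 x 256]
        cls256 = vit16_256(patches)                                        # [256 x 384] patch-level tokens
        grid = cls256.reshape(16, 16, 384).permute(2, 0, 1).unsqueeze(0)   # [1 x 384 x 16 x 16] feature grid
        region_tokens.append(vit256_4k(grid))                              # [1 x 192] region-level token
    region_tokens = torch.cat(region_tokens, dim=0)                        # [M x 192]
    return attention_pool(region_tokens)                                   # slide-level representation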

The authors have released their implementation. Let's take a look at some of the scripts.

The script below implements the Vision Transformer (ViT) building blocks used for high-resolution image analysis, incorporating several techniques:

1. Truncated Normal Initialization: A technique for initializing neural network weights in a manner that avoids large deviations from the mean, ensuring stability in early training phases.

2. Drop Path: A regularization method that randomly drops paths in the network during training to improve generalization by simulating a thinner network, akin to dropout but for residual connections.

3. Multi-Layer Perceptron (MLP) Module: Defines a simple two-layer MLP with a GELU activation function and dropout, used within the transformer blocks to process features.

4. Attention Mechanism: Implements the self-attention mechanism with optional bias and scaling, crucial for capturing global dependencies in the input data.

5. Transformer Block: Combines the norm layer, attention mechanism, and MLP into a cohesive block, with optional path dropout for regularization.

6. VisionTransformer4K: A specialized version of the Vision Transformer designed for very high-resolution images, incorporating techniques like positional embedding interpolation for adapting to different image sizes, and a structure optimized for processing large-scale images.

7. Utility Functions: Includes functions for truncated normal weight initialization, drop path simulation, and parameter counting to aid in model setup and analysis.

import argparse
import os
import sys
import datetime
import time
import math
import json
import warnings
from functools import partial
from pathlib import Path

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision import models as torchvision_models

import vision_transformer as vits
from vision_transformer import DINOHead

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)



def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer4K(nn.Module):
    """ Vision Transformer 4K """
    def __init__(self, num_classes=0, img_size=[224], input_embed_dim=384, output_embed_dim=192,
                 depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 num_prototypes=64, **kwargs):
        super().__init__()
        embed_dim = output_embed_dim
        self.num_features = self.embed_dim = embed_dim
        self.phi = nn.Sequential(*[nn.Linear(input_embed_dim, output_embed_dim), nn.GELU(), nn.Dropout(p=drop_rate)])
        num_patches = int(img_size[0] // 16)**2
        print("# of Patches:", num_patches)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // 1
        h0 = h // 1
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def prepare_tokens(self, x):
        # print('preparing tokens (after crop)', x.shape)
        self.mpp_feature = x
        B, embed_dim, w, h = x.shape
        x = x.flatten(2, 3).transpose(1, 2)

        x = self.phi(x)

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x[:, 0]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # return attention of the last block
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # we return the output tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output


def vit4k_xs(patch_size=16, **kwargs):
    model = VisionTransformer4K(
        patch_size=patch_size, input_embed_dim=384, output_embed_dim=192,
        depth=6, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
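As a quick sanity check, here is a hedged usage sketch (mine, not part of the repository) that instantiates the small ViT-4K variant and runs it on a random 16 x 16 grid of 384-dimensional patch tokens, which is the kind of input it receives inside HIPT:

# Usage sketch (assumed shapes): a 4096 x 4096 region yields a 16 x 16 grid of
# 384-dim ViT-256 [CLS] tokens, which ViT-4K maps to a single 192-dim embedding.
model = vit4k_xs()
print("Trainable parameters:", count_parameters(model))

dummy_grid = torch.randn(1, 384, 16, 16)   # [B x input_embed_dim x w_256 x h_256]
with torch.no_grad():
    region_embedding = model(dummy_grid)   # [1 x 192] region-level [CLS] token
print(region_embedding.shape)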

The script below provides the loading and evaluation utilities for the Vision Transformer (ViT) models, specifically designed for high-resolution images like those in computational pathology. It defines functions to:

1. Load pre-trained ViT models (`get_vit256` and `get_vit4k`) with options for different architectures and device settings, initializing them in evaluation mode with gradient computation disabled.
2. Apply transformations (`eval_transforms`) for model evaluation, normalizing images with a specific mean and standard deviation.
3. Convert batches of image tensors into a single PIL image (`roll_batch2img`) or a numpy array (`tensorbatch2im`), facilitating the handling of image data for visualization or further processing.

### Dependencies
# Base Dependencies
import argparse
import colorsys
from io import BytesIO
import os
import random
import requests
import sys

# LinAlg / Stats / Plotting Dependencies
import cv2
import h5py
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import numpy as np
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw
from scipy.stats import rankdata
import skimage.io
from skimage.measure import find_contours
from tqdm import tqdm
import webdataset as wds

# Torch Dependencies
import torch
import torch.multiprocessing
import torchvision
from torchvision import transforms
from einops import rearrange, repeat
torch.multiprocessing.set_sharing_strategy('file_system')

# Local Dependencies
import vision_transformer as vits
import vision_transformer4k as vits4k

def get_vit256(pretrained_weights, arch='vit_small', device=torch.device('cuda:0')):
    r"""
    Builds ViT-256 Model.

    Args:
    - pretrained_weights (str): Path to ViT-256 Model Checkpoint.
    - arch (str): Which model architecture.
    - device (torch): Torch device to save model.

    Returns:
    - model256 (torch.nn): Initialized model.
    """

    checkpoint_key = 'teacher'
    device = torch.device("cpu")
    model256 = vits.__dict__[arch](patch_size=16, num_classes=0)
    for p in model256.parameters():
        p.requires_grad = False
    model256.eval()
    model256.to(device)

    if os.path.isfile(pretrained_weights):
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        if checkpoint_key is not None and checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model256.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))

    return model256


def get_vit4k(pretrained_weights, arch='vit4k_xs', device=torch.device('cuda:1')):
    r"""
    Builds ViT-4K Model.

    Args:
    - pretrained_weights (str): Path to ViT-4K Model Checkpoint.
    - arch (str): Which model architecture.
    - device (torch): Torch device to save model.

    Returns:
    - model4k (torch.nn): Initialized model.
    """

    checkpoint_key = 'teacher'
    device = torch.device("cpu")
    model4k = vits4k.__dict__[arch](num_classes=0)
    for p in model4k.parameters():
        p.requires_grad = False
    model4k.eval()
    model4k.to(device)

    if os.path.isfile(pretrained_weights):
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        if checkpoint_key is not None and checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model4k.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))

    return model4k


def eval_transforms():
    """
    Evaluation transforms: convert to tensor and normalize pixel values to [-1, 1].
    """
    mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    eval_t = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    return eval_t


def roll_batch2img(batch: torch.Tensor, w: int, h: int, patch_size=256):
    """
    Rolls an image tensor batch (batch of [256 x 256] images) into a [W x H] PIL.Image object.

    Args:
        batch (torch.Tensor): [B x 3 x 256 x 256] image tensor batch.

    Return:
        Image.PIL: [W x H x 3] Image.
    """
    batch = batch.reshape(w, h, 3, patch_size, patch_size)
    img = rearrange(batch, 'p1 p2 c w h -> c (p1 w) (p2 h)').unsqueeze(dim=0)
    return Image.fromarray(tensorbatch2im(img)[0])


def tensorbatch2im(input_image, imtype=np.uint8):
    r"""
    Converts a Tensor array into a numpy image array.

    Args:
    - input_image (torch.Tensor): (B, C, W, H) Torch Tensor.
    - imtype (type): the desired type of the converted numpy array

    Returns:
    - image_numpy (np.array): (B, W, H, C) Numpy Array.
    """
    if not isinstance(input_image, np.ndarray):
        image_numpy = input_image.cpu().float().numpy()  # convert it into a numpy array
        # if image_numpy.shape[0] == 1:  # grayscale to RGB
        #     image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (0, 2, 3, 1)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)
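A small hedged usage sketch (mine, not from the repository) tying these helpers together: normalize a PIL image for the model, then map the tensor back to a uint8 array for visualization.

# Usage sketch: eval_transforms() maps pixels to [-1, 1]; tensorbatch2im undoes that scaling.
dummy_img = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))
batch = eval_transforms()(dummy_img).unsqueeze(0)   # [1 x 3 x 256 x 256], values in [-1, 1]
recovered = tensorbatch2im(batch)                   # [1 x 256 x 256 x 3] uint8 array
print(batch.shape, recovered.shape, recovered.dtype)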

Finally, this script defines the HIPT_4K model, which integrates both Vision Transformers to process high-resolution images. It loads the pretrained ViT models for the 256x256 and 4096x4096 scales and applies them hierarchically to an input image: the image is cropped into 256x256 patches, ViT_256 extracts a feature vector from each patch, and these features are then fed into ViT_4K to obtain a global representation. This hierarchical approach lets the model handle non-square, high-resolution images while extracting detailed features at both local and global scales, in line with the paper's methodology of leveraging hierarchical structure for image analysis.

import torch
from einops import rearrange, repeat
from HIPT_4K.hipt_model_utils import get_vit256, get_vit4k


class HIPT_4K(torch.nn.Module):
    """
    HIPT Model (ViT_4K-256) for encoding non-square images (with [256 x 256] patch tokens), with
    [256 x 256] patch tokens encoded via ViT_256-16 using [16 x 16] patch tokens.
    """
    def __init__(self,
                 model256_path: str = 'path/to/Checkpoints/vit256_small_dino.pth',
                 model4k_path: str = 'path/to/Checkpoints/vit4k_xs_dino.pth',
                 device256=torch.device('cuda:0'),
                 device4k=torch.device('cuda:1')):

        super().__init__()
        self.model256 = get_vit256(pretrained_weights=model256_path).to(device256)
        self.model4k = get_vit4k(pretrained_weights=model4k_path).to(device4k)
        self.device256 = device256
        self.device4k = device4k

    def forward(self, x):
        """
        Forward pass of HIPT (given an image tensor x), outputting the [CLS] token from ViT_4K.
        1. x is center-cropped such that the W / H is divisible by the patch token size in ViT_4K (e.g. - 256 x 256).
        2. x then gets unfolded into a "batch" of [256 x 256] images.
        3. A pretrained ViT_256-16 model extracts the CLS token from each [256 x 256] image in the batch.
        4. These batch-of-features are then reshaped into a 2D feature grid (of width "w_256" and height "h_256".)
        5. This feature grid is then used as the input to ViT_4K-256, outputting [CLS]_4K.

        Args:
        - x (torch.Tensor): [1 x C x W' x H'] image tensor.

        Return:
        - features_cls4k (torch.Tensor): [1 x 192] cls token (d_4k = 192 by default).
        """
        # prepare_img_tensor (defined in the full class in the authors' repository)
        # center-crops x so that its width and height are divisible by 256.
        batch_256, w_256, h_256 = self.prepare_img_tensor(x)                     # 1. [1 x 3 x W x H]
        batch_256 = batch_256.unfold(2, 256, 256).unfold(3, 256, 256)            # 2. [1 x 3 x w_256 x h_256 x 256 x 256]
        batch_256 = rearrange(batch_256, 'b c p1 p2 w h -> (b p1 p2) c w h')     # 2. [B x 3 x 256 x 256], where B = (1*w_256*h_256)

        features_cls256 = []
        for mini_bs in range(0, batch_256.shape[0], 256):                        # 3. B may be too large for ViT_256. We further take minibatches of 256.
            minibatch_256 = batch_256[mini_bs:mini_bs+256].to(self.device256, non_blocking=True)
            features_cls256.append(self.model256(minibatch_256).detach().cpu())  # 3. Extracting ViT_256 features from [256 x 3 x 256 x 256] image batches.

        features_cls256 = torch.vstack(features_cls256)                          # 3. [B x 384], where 384 == dim of ViT-256 [CLS] token.
        features_cls256 = features_cls256.reshape(w_256, h_256, 384).transpose(0, 1).transpose(0, 2).unsqueeze(dim=0)
        features_cls256 = features_cls256.to(self.device4k, non_blocking=True)   # 4. [1 x 384 x w_256 x h_256]
        features_cls4k = self.model4k.forward(features_cls256)                   # 5. [1 x 192], where 192 == dim of ViT_4K [CLS] token.
        return features_cls4k
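To close, a hedged end-to-end sketch (mine) of running HIPT_4K on a single 4096 x 4096 region. The checkpoint paths are placeholders for the DINO-pretrained weights, and the crop helper `prepare_img_tensor` comes from the full class in the repository.

# Usage sketch: encode a single [1 x 3 x 4096 x 4096] region into a 192-dim embedding.
model = HIPT_4K(
    model256_path='path/to/Checkpoints/vit256_small_dino.pth',   # placeholder path
    model4k_path='path/to/Checkpoints/vit4k_xs_dino.pth',        # placeholder path
    device256=torch.device('cuda:0'),
    device4k=torch.device('cuda:0'),
)
model.eval()

region = torch.randn(1, 3, 4096, 4096)        # e.g. eval_transforms() applied to a WSI region
with torch.no_grad():
    region_embedding = model(region)          # [1 x 192] ViT_4K [CLS] token
print(region_embedding.shape)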
