Yoga Assistant Part 1: 圖像分類模型(CNN/ViT)

楊哲寧 (Jeff Yang)
28 min readMay 1, 2024

--

這陣子工作之餘應朋友之託訓練了一個瑜伽姿勢的分類模型,結合LLMs打造Yoga Assistant Chatbot,並通過AWS Service 來部署,想將開發的經驗寫成一個系列,希望能維持一週~二週一更的速度(先立flag)。此篇是這個系列的首篇,將詳細介紹如何迅速構建並訓練一個可靠的圖像分類模型。儘管近期談到AI不外乎都是生成式AI相關技術與應用,然而產業界許多的底層任務還是透過較小型的模型來完成,擁有扎實 CV/NLP 背景知識不管是在工作、生活上帶來不小幫助,準備好了就讓我們開始吧!

Source Coe: yoga-model-hub

資料收集與分析

定義明確、標註一致(Consistency)的資料集是所有任務的第一步也是最重要的一步。一開始收到朋友的需求時,原本是打算寫個爬蟲去搜集資料,但礙於沒有瑜珈相關的背景知識,不確定該如何定義分類的類別與架構,好在隨後就發現了Yoga-82 的文獻資料集

Yoga-82 資料集將瑜伽姿勢拆分成層狀(Hierarchical)的標注結構,一共三層,第一層有6個類別,第二層20個類別,第三層也是最後一層有82個類別。取決於最終用途,我們可以考慮建造一個 三層 Hierarchical Classifier 或是只分類最後一層,然而就算只需要最後的分類,我們還是可以將前兩層當作 Auxiliary Classifier 來提昇收斂速度與模型最終表現。

Source Yoga-82 propose the concept of fine-grained hierarchical pose classification and propose a large-scale pose dataset called Yoga-82, comprising of multi-level class hierarchy based on the visual appearance of the pose.

模型建構

這裡我會介紹如何使用PyTorch Image Models以及 DINOv2 來客製化我們需要的模型架構,需要一些背景知識,對CNN原理有興趣的讀者可以參閱早期寫的一些文章,雖然已經歷史久遠,但許多基礎概念到如今都還是通用的!

CNN原理: Reference

2019年前的 CNN 模型架構: Reference

ConvNeXt 原理概述

如果只對實踐步驟感興趣可以直接跳過這段~

此範例我們用ConvNeXt,更精確來說是 convnext_small.in12k_ft_in1k_384 (Model Card),選擇此模型單純因為以參數量50M左右的模型來說,其表現相當不錯。ConvNeXt 由 Facebook AI Research (FAIR), UC Berkeley 發表,其目標在於提供能與ViT 競爭的 ConvNets(CNN) 架構。

作者在原文中提出原始的ViT有許多問題(例如沒有 CNN 的 Inductive Biases),之後的一系列改進主要也是要帶回CNN的優勢,但所需的運算量與參數量卻更大,訓練也較為不易,除此之外,ViT模型展現出來的優勢也有一部分必須歸功於複雜的訓練策略、參數設計。

ConvNeXt 以ResNet為模板,參考 ViT 一系列的設計,如提高 Stage/Stem 層參數比,拉大Stem Kernel Size,使用Depthwise Convolution 同時加大Width來彌補 FLOPs的降低,激活函數改用GELU等等,大幅提升了 ResNet架構的表現。

Source

ConvNeXt V2 近一步導入的Transformer-based Model 預訓練時常用的 self-supervised learning,加以改良提出 Fully Convolutional Masked Autoencoder (FCMAE)。

Sparse convolution-based ConvNeXt encoder and a lightweight ConvNeXt block decoder (Source)

然而導入FCMAE並無法有效提升模型表現,作者們認為是因為許多激活後的 Feature Map 塌陷或飽和。

Feature cosine distance analysis (source)

為了解決此問題,提出了 Global Response Normalization (GRN)

# gamma, beta: learnable affine transform parameters 
# X: input of shape (N,H,W,C)

# L2-norm-based feature aggregation, can be viewed as a simple pooling layer
gx = torch.norm(X, p=2, dim=(1,2), keepdim=True)
# Standard divisive normalization
nx = gx / (gx.mean(dim=-1, keepdim=True)+1e-6)
# Calibrate with computed feature normalization scores
return gamma * (X * nx) + beta + X

結論,ConvNeXt V2 並無並無顯著更動模型架構,但提出 FCMAE 、GRN, 兩者搭配使用表現更好。

有興趣讀者可以參考原論文:

ConvNeXt 實作 (PyTorch Image Models)

PyTorch Image Models (timm) 現在隸屬於Hugging Face Hub,提供許多模型架構、預訓練權重,其在 0.6.x版本後大幅更新API介面並導入了許多新功能,並且持續在更新與優化。

首先我們先來搭建 Backbone,drop_path_rate 是之前在優化訓練的文章(reference)中介紹的 Stochastic Depth,features_only 讓我們只使用feature extractor(不用heads),out_indices 可以靈活的取出不同 layers 的輸出。

import timm
timm_model = "convnext_small.in12k_ft_in1k_384"
backbone = timm.create_model(
model_name=timm_model,
pretrained=True,
drop_rate=0.1,
drop_path_rate=0.5,
features_only=True,
out_indices=[-2, -1], #因應Hierarchical Annotation 架構考量
)

tensor = torch.randn(1, 3, 256, 256)
outupt = backbone(tensor)
# torch.Size([1, 384, 16, 16])
print(outupt[0].shape)
# torch.Size([1, 768, 8, 8])
print(outupt[1].shape)

有了Backbone後我們就能快速搭建完整的模型架構,下方有一些實踐稍微複雜是因為實驗時要能依照選擇模型靈活的切換參數,大致架構如下:

  1. 依照選擇的 timm model 導入 backbone 並提取最後兩層。
  2. backbone 輸出先經過 AvgPool 再經過 Mlp 層輸出預測目標類別。
  3. backbone 輸出第一層用於預測 Yoga-82 的第一層以及第二層標注(6, 20 個類別),backbone 輸出第二層用於預測 Yoga-82 的第三層(82個類別),之所以會這樣設計是依照經驗以及快速測試的結果,大家也可以靈活更動設計,如只使用backbone最後一層來預測 Yoga-82的三層標注。

Source Code 可以參考: Link

import logging
from typing import Callable, List, Optional

import timm
import torch
import torch.nn as nn
from torch import Tensor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)

def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
return x


class TimmModelWrapper(nn.Module):
model_layer_config = {
"convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384": [512, 512, 1024],
"convnext_small.in12k_ft_in1k_384": [384, 384, 768],
"coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k": [512, 512, 1024],
}

def __init__(
self,
timm_model: str = "convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384",
multiclassifier: List[int] = [6, 20, 82],
drop_rate: float = 0.2,
drop_path_rate: float = 0.5,
pretrained: bool = True,
**kwargs,
):
super().__init__()
if timm_model == "coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k":
logger.warning(f"Input image size have to be 384 * 384")

self.backbone = timm.create_model(
timm_model,
pretrained=pretrained,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
features_only=True,
out_indices=[-2, -1],
**kwargs,
)
self.adaptive_pooling = nn.AdaptiveAvgPool2d((1, 1))

self.multiclassifier = nn.ModuleList([])
layer_size = TimmModelWrapper.model_layer_config.get(timm_model)

for num_class, size in zip(multiclassifier, layer_size):
self.multiclassifier.append(
Mlp(size, out_features=num_class, drop=drop_rate)
)

num = self.count_parameters()
logger.info(f"Use bakcbone: {timm_model}, Total parameters: {num}")

def count_parameters(
self,
) -> int:
"""Counts the number of trainable parameters in a PyTorch model.

Args:
model (nn.Module): The model to count parameters for.

Returns:
int: The number of trainable parameters.
"""
num = sum(p.numel() for p in self.parameters() if p.requires_grad)
num = "{:.2f}M".format(num / 1_000_000)
return num

def forward(self, x):
batch = x.shape[0]
stack_layers = self.backbone(x)
outputs: List[torch.Tensor] = []

# Determine the index at which to switch from stack_layers[0] to stack_layers[1]
switch_index = len(self.multiclassifier) - len(stack_layers) + 1

for i, layer in enumerate(self.multiclassifier):
# Choose the tensor from stack_layers based on the current index
tensor_index = 0 if i < switch_index else 1
tensor = self.adaptive_pooling(stack_layers[tensor_index]).reshape(
batch, -1
)
tensor = layer(tensor)
outputs.append(tensor)
return outputs[0], outputs[1], outputs[2]

測試一下輸出格式是否符合預期

model = TimmModelWrapper()
output = model(torch.randn(1, 3, 384, 384))

#torch.Size([1, 6])
print(output[0].shape)
#torch.Size([1, 20])
print(output[1].shape)
#torch.Size([1, 82])
print(output[2].shape)

DINO 原理概述

如果只對實踐步驟感興趣可以直接跳過這段~

DINO指的是一種 Self-supervised Learning + Knowledge Distillation 的學習方式,並不是特定的模型架構,原作中使用的模型是ViT的架構,作者們也有測試 ResNet-50,但是 DINO 在 Transformer-based model 有較顯著的效果。

DINO pre-trained 的 small ViT 架構接上 k-NN classifiers 可以在 Imagenet-1k 的資料集上達到 top-1 78.3% 的精度 ,如果加上 supervised fine-tuning ,最終精度可以達到達到 81.5% ,比沒有 DINO 的版本提升了 1.6%。乍看之下好像不是特別厲害,但細想在不需要利用任何標註的情況下就能在 Imagenet 達到 78.3% 的精度,我是覺得還滿有潛力的,除此之外 DINO 訓練的模型自帶 object segmentations 的能力,也是 classification supervised learning 上不會直接獲得的特徵。

source

DINO 的架構如下:

  1. 類 Teacher-Student-Network:與Teacher-Student-Network不同的是, teacher 模型並不是預訓練模型、架構跟 student 模型共享、不參與 back-propagation、由 student 模型參數的 exponential moving average (EMA) 更新(給予近期參數更高權重的更新方法)。
  2. 單一輸入但是使用不同(隨機)的 Data Augmentation。
  3. teacher 模型使用 centering + sharpening 來提升訓練穩定度,避免 teacher 模型輸出塌陷。

可以參考下方實踐(部分步驟簡化,但邏輯一致)

import torch
import torch.nn as nn
import torch.optim as optim

def softmax(x, temperature, dim):
scaled_x = x / temperature
exp_x = torch.exp(scaled_x - torch.max(scaled_x, dim=dim, keepdim=True).values)
return exp_x / exp_x.sum(dim=dim, keepdim=True)

class SimpleModel(nn.Module):
def __init__(self, feature_dim, num_classes):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(feature_dim, num_classes)

def forward(self, x):
return self.fc(x)

# Hyperparameters
num_classes = 10
feature_dim = 5
batch_size = 4
tps = 0.1 # Temperature for student's softmax
tpt = 0.1 # Temperature for teacher's softmax (sharpening)
l = 0.9 # Mixing coefficient for teacher update
m = 0.9 # Momentum for centering update
C = torch.zeros(num_classes) # Center vector

# Initialize models
student = SimpleModel(feature_dim, num_classes)
teacher = SimpleModel(feature_dim, num_classes)
teacher.load_state_dict(student.state_dict()) # Start with the same parameters

# Optimizer for the student
optimizer = optim.SGD(student.parameters(), lr=0.01)

# Dummy data loader
data_loader = [torch.randn(batch_size, feature_dim) for _ in range(10)] # Random data

# Training loop
for x in data_loader:
optimizer.zero_grad()

# Simulate two different augmentations
x1, x2 = x + torch.randn_like(x) * 0.1, x + torch.randn_like(x) * 0.1

# Pass through the models
s1, s2 = student(x1), student(x2)
t1, t2 = teacher(x1), teacher(x2)

# Compute the loss
def H(t, s):
t = t.detach() # Detach teacher's output
s = softmax(s, tps, dim=1)
t = softmax(t - C, tpt, dim=1)
return -(t * torch.log(s)).sum(dim=1).mean()

loss = (H(t1, s2) + H(t2, s1)) / 2
loss.backward()
optimizer.step()

# Update teacher by moving average
with torch.no_grad():
for student_param, teacher_param in zip(student.parameters(), teacher.parameters()):
teacher_param.data = l * teacher_param.data + (1 - l) * student_param.data

# Update center
C = m * C + (1 - m) * torch.cat([t1, t2], dim=0).mean(dim=0)

print("Loss:", loss.item())

DINOv2 進一步提升 Self-supervised Pre-training 的設計,並且清理預訓練資料集,提升品質與多樣性(不管是什麼任務,Data 一定是影響最大的,模型、訓練方法其次),不過後續研究指出 DINOv2 學到的特徵相較於 DINOv1 反而有更多瑕疵,下面會再解釋。

對 DINOv2 提升細節有興趣的讀者,可以參考原文第四五章節,已經寫的滿簡略易讀,Meta也有釋出Demo,大家可以玩玩看: Demo

DINOv2 paper: https://arxiv.org/pdf/2304.07193

source

DINO 系列第三篇 VISION TRANSFORMERS NEED REGISTERS (Link),發現除了 DINOv1 外,其餘 Transformer-based Model (包含 DINOv2)都有 Attention Map Artifact,可以理解為 Attention Map 上的瑕疵,專注在主體以外的部分。

source

藉由分析結果,作者們提出這種現象主要是因為:

  1. 缺乏 Spatial Information :

證明點 - 使用 Outlier token 預測位置時精度大幅降低。

2. 受到 Global Information 特徵影響 :

證明點 - 在圖像分類任務上 Outlier token 表現 比 Normal token 好,但這反而會讓模型整體的通用性降低。

文中提出的解決方法也很簡單,就是增加幾個可訓練的 Register Tokens,用來儲存 Global Information 的特徵,但不管是在 訓練或是預測階段,只使用既有的 Input Patch Tokens 以及 CLS Token。

source

觀看 Attention Map Visualization 以及 Benchmarking Metrics,確實可以發現有 Register 的模型比較專注,結果也有些微提升,不過!! DINOv1的結果還是最好的!!有時候簡單直覺的設計往往擁有最好的效果,像是CNN Backbone 層出不窮,但我們往往還是聽到 ResNet 或其改進版本。

source

DINO 實作

DINO pre-trained Backbone 可以由 PyTorch Hub 獲得 (reference),我們只需要去思考如何使用 Backbone 的輸出結構,我有嘗試過兩種 Head 的設計,最後採用 B 方案:

  1. Backbone 輸出四層,使用第0層的特徵來預測瑜伽標註第一層(6個類別),第2層的特徵來預測瑜伽標註第二層(20個類別)以及第3層的特徵來預測瑜伽標註第三層(82個類別)。
  2. 將 CLS token 與 feature map tokens的平均連結後輸出。
class _LinearClassifierWrapperB(nn.Module):
def __init__(
self,
*,
backbone: nn.Module,
multiclassifier: List[int] = [],
):
super().__init__()
self.backbone = backbone
out_dimension = self.backbone.blocks[-1].mlp.fc2.out_features
self.multiclassifier = nn.ModuleList(
[
Mlp(out_dimension * 2, out_features=num_class, drop=0.3)
for num_class in multiclassifier
]
)

def forward(self, x):
x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
linear_input_6 = torch.cat(
[
x[0][1],
x[0][0].mean(dim=1),
],
dim=1,
)
linear_input_20 = torch.cat(
[
x[2][1],
x[2][0].mean(dim=1),
],
dim=1,
)
linear_input_82 = torch.cat(
[
x[3][1],
x[3][0].mean(dim=1),
],
dim=1,
)
linear_input = [linear_input_6, linear_input_20, linear_input_82]
# fmt: on
output = []
for input_tensor, layer in zip(linear_input, self.multiclassifier):
output.append(layer(input_tensor))
return output

接下來來完成整個模型的搭建

def _make_dinov2_linear_classifier(
*,
arch_name: str = "vit_base",
layers: int = 4,
pretrained: bool = True,
weights: Union[Weights, str] = Weights.IMAGENET1K,
multiclassifier: List[int] = [],
**kwargs,
):
if arch_name == "dino2_vit_base":
backbone_name = "dinov2_vitb14"
if arch_name == "dino2_vit_small":
backbone_name = "dinov2_vits14"
logger.info(f"Activate dinov2 with backbone: {backbone_name}")
backbone = torch.hub.load("facebookresearch/dinov2", backbone_name, **kwargs)

return _LinearClassifierWrapperB(backbone=backbone, multiclassifier=multiclassifier)


def dinov2_vitb14_lc(
*,
layers: int = 4,
pretrained: bool = True,
weights: Union[Weights, str] = Weights.IMAGENET1K,
multiclassifier=[],
version="vit_base",
**kwargs,
):
"""
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
"""
return _make_dinov2_linear_classifier(
arch_name=version,
layers=layers,
pretrained=pretrained,
weights=weights,
multiclassifier=multiclassifier,
**kwargs,
)

model = dinov2_vitb14_lc(
multiclassifier=[6, 20, 82], pretrained=False, version="dino2_vit_small"
)

Source Code 可以參考: Link

訓練以及預測

訓練

訓練原始碼有點冗長,這邊就不直接貼上,完成原始碼可以直接看 Training Script : Link

預測

我將模型打包成YogaClassifier,可以直接使用 DINOv2 或是 ConvNeXt 的 Backbone,訓練好的權重也能直接從 google drive 下載,詳細請參考 README

from yogahub.models import YogaClassifier
from PIL import Image
import numpy as np

# Initialize the model
model = YogaClassifier(backbone="dino2_vit_base", pretrained="weight/classify/dinov2_vitb14.pth")
model = YogaClassifier(backbone="convnext_small.in12k_ft_in1k_384", pretrained="weight/classify/convnext_small.in12k_ft_in1k_384.pth")

# Example: Predict the yoga pose in an image
output = model.predict("example/test.png", convert_to_chinese=True)

# Sample output:
#{
# "Target":"戰士二式(Virabhadrasana Ii)",
# "PotentialCandidate":[
# "舞王式(Natarajasana)",
# "單腳向上延展式(Urdhva Prasarita Eka Padasana)",
# "單腿站立伸展式(Utthita Padangusthasana)",
# "戰士三式(Virabhadrasana Iii)"
# ],
# "Gesture":"站立"
# }

在模型最終精度表現上,兩個架構差距不大,但是 dino2_vit_base 的參數量是93.83M,convnext_small 只有50.41M。由於訓練的時候 DINOv2 registers 版本還沒釋出,因此只有沒 registers 的版本。

Accuracy
dino2_vit_base: 0.9774, 0.9679, 0.9528 / class_6, class_20, class_82
convnext_small: 0.9849, 0.9719, 0.9518 / class_6, class_20, class_82

總結

這篇使用ConvNeXt 與 DINOv2 為範本,介紹如何快速搭圖像分類模型並獲得不錯的結果,文內也有簡單講解不同架構背後的原理,其他背景知識則可以參考先前發表的文章或是附上的參考資料。

後記

此篇是Yoga Assistant系列的首篇,後續會介紹的主題包含 模型壓縮、模型封裝 (Docker) 、模型部署(AWS/Serving/Deployment)、構建 RAG pipeline (LLMs) ,整合分類模型打造 Yoga Assistant Chatbot,相關的技術有許多不同的應用場景。由於擔心文章太冗長,部分細節斟酌省略,如果有任何想法或是建議,也歡迎直接跟我聊聊交流。

我的 LinkedIn : Jeff Yang

--

--

楊哲寧 (Jeff Yang)

目前在 Cinnamon AI Global 擔任 AI Director,研究、開發項目著重於智慧型文件處理(IDP)以及檢索增強生成(RAG),相關服務在日常生活中也有許多有趣且實用的應用場景,隨手筆記並期望能與大家多交流。