[Hands-On] ViT를 활용한 헤드 기반 이미지 분류

Hugman Sangkeun Jung
23 min readApr 14, 2024

--

교육용 목적으로 작성된 코드입니다.

(영어버젼의 포스트는 링크에서 찾아볼 수 있습니다.)

이 포스트는 Head기반 분류 기술 실습 시리즈 중 2번째 실습입니다. 이전 포스트에서 우리는 텍스트에서의 헤드기반 분류를 살펴보았습니다.

이번 포스트에서는 이미지에서의 헤드기반 분류를 자세하게 살펴보겠습니다. 이 실습에서는 Vision Transformer (ViT)를 이용합니다. 먼저 Hugging Face의 transformers 라이브러리를 사용하여 선학습된 모델을 다운로드 받고, 그 모델에 헤드기반 분류가 어떻게 적용되는지 살펴보겠습니다. 데이터셋으로는 사과, 체리 등의 과일 이미지를 포함하는 과일 데이터셋을 사용할 예정입니다.

ViT는 무엇인가요?

Vision Transformer (ViT)는 NLP에서 널리 성공한 트랜스포머의 원리를 이미지 분류 작업에 성공적으로 적용한 신경망 구조입니다. 픽셀이나 컨볼루셔널 특징 형태로 이미지를 처리하는 대신, ViT는 이미지를 일련의 패치로 취급하고 이러한 시퀀스에 대해 분류 작업을 위해 트랜스포머 모델을 적용합니다. BERT를 충분히 이해하셨다면, ViT도 바로 이해하실 수 있습니다. ViT는 이미지 버젼의 BERT거든요.

Head 기반 분류란 무엇인가요?

Head 기반 분류는 사전 훈련된 Vision Transformer (ViT) 모델에 하나 또는 여러 개의 “헤드”(신경망 층 또는 층의 집합)를 추가하여 특정 시각적 작업을 수행하는 접근 방식입니다. 예를 들어, 이미지 내의 객체를 분류하는 작업이 그것입니다. ViT에서 “헤드”는 각 이미지 패치에서 추출된 특성을 통합하고, 이를 기반으로 분류 작업에 대한 예측을 출력하기 위해 최적화됩니다. 이렇게 함으로써, 충분한 데이터에 기반하여 ViT가 학습한 풍부한 시각적 표현을 활용하여 네트워크가 해결하려는 문제 — 여기서는 객체 분류 — 를 효과적으로 수행할 수 있습니다.

ViT의 헤드 기반 접근 방식을 활용함으로써, 우리는 사전 훈련된 모델을 특정 시각적 작업에 적합하게 상대적으로 작은 데이터셋으로 미세 조정할 수 있습니다.

ViT와 Head 기반 분류

Vision Transformer (ViT) 모델에서는 이미지의 각 부분을 패치로 나누고 이 패치들을 Transformer의 입력으로 사용합니다. 이 과정에서 각 이미지 패치는 별도의 토큰으로 처리되며, 시각적 작업에서 중요한 역할을 합니다. 여기서는 이미지 패치와 분류 헤드 간의 상호작용을 자세하게 단계별로 설명하겠습니다.

  1. 사전 훈련된 ViT 초기화
    대규모 이미지 데이터셋으로부터 사전 훈련된 ViT 모델을 준비합니다. 이 모델은 이미 객체 인식, 시각적 맥락 파악 등 시각적 정보 처리에 필요한 기본적인 능력을 가지고 있습니다.
  2. 패치 및 클래스 토큰 추가
    각 입력 이미지를 여러 개의 패치로 나누고, 이 패치들을 Transformer의 입력으로 변환합니다. 각 패치는 독립적인 토큰으로 처리됩니다. 여기에 더해, 이미지의 전체적인 정보를 대표할 수 있는 클래스 토큰(유사 [CLS] 토큰)이 각 이미지 시퀀스의 시작 부분에 추가됩니다. 이 클래스 토큰은 Transformer의 모든 레이어를 통과하며, 각 레이어에서의 자기 주의 메커니즘을 통해 패치들의 정보와 통합되어 이미지 전체의 맥락을 학습하는 데 중요한 역할을 합니다. 위치 토큰 또한 각 패치와 함께 추가되어 Transformer 내에서의 패치 위치 정보를 유지하며, 정확한 공간적 맥락을 학습하는 데 도움을 줍니다.
  3. 분류 헤드 설계
    ViT의 최상단 레이어에서 패치 기반의 정보를 종합하여 이미지 전체를 대표하는 벡터를 얻어냅니다. 이 벡터를 분류 헤드에 연결합니다. 분류 헤드는 일반적으로 소프트맥스 레이어를 포함한 간단한 신경망으로 구성되어, 이미지의 시각적 표현을 특정 클래스 레이블로 매핑합니다.
  4. 미세 조정(Fine-tuning)
    분류 작업에 특화된 데이터셋에서 ViT와 분류 헤드를 결합한 모델을 미세 조정합니다. 이 단계에서는 특정 사용 사례에 최적화되도록 ViT의 파라미터와 분류 헤드를 함께 조정합니다. 이를 통해 모델은 더욱 정확하게 이미지를 분류할 수 있는 능력을 개발합니다.
Head-based classification with ViT

구현 내용 미리보기

이번 실습에서는 다음과 같은 주요 단계를 거치게 됩니다.

  1. 데이터셋 준비 및 모델 초기화: 우리는 여러 종류의 과일 이미지를 포함하는 데이터셋을 사용할 것입니다. 이 데이터셋을 로드하고, Hugging Face의 transformers 라이브러리를 사용하여 사전 훈련된 Vision Transformer (ViT) 모델을 준비합니다.
  2. 패치 및 클래스 토큰 추가: ViT는 이미지를 여러 패치로 나누고, 각 패치를 독립적인 토큰으로 처리합니다. 또한, 이미지 전체의 정보를 대표할 수 있는 클래스 토큰을 추가하여 모델이 이미지의 전체적인 맥락을 이해할 수 있도록 합니다.
  3. 미세 조정 및 평가: 사전 훈련된 ViT 모델과 추가된 분류 헤드를 사용하여 데이터셋을 미세 조정합니다. 훈련 과정은 배치별로 데이터를 모델에 공급하고, 에러를 역전파하여 모델의 가중치를 업데이트합니다. 검증 세트를 사용하여 모델 성능을 주기적으로 평가하며, 정확도, 정밀도, 재현율, F1 점수 등의 메트릭을 사용하여 최종 모델을 평가합니다.
  4. 결과 시각화 및 분석: 훈련된 모델의 성능을 Confusion Matrix와 같은 시각적 도구를 사용하여 분석하고 시각화합니다. 이를 통해 모델이 각 과일 카테고리에 대해 얼마나 잘 분류하는지, 어떤 카테고리에서 오류가 발생하는지 등을 파악할 수 있습니다.

환경 준비하기

먼저 데이터를 다운로드 받고 준비합니다. 해당 데이터셋은 제가 미리 만들어 dropbox를 통해 zip 파일로 공유하고 있으니 받아서 아래와 같이 처리하면 됩니다.

!rm data.zip
!rm -r data
!wget -qq 'https://www.dropbox.com/scl/fi/mf806inzv0x6abbb2su0k/fruit-dataset.zip?rlkey=gk7s9d14m6k9o1ru4whrgtwke&dl=1' -O data.zip
!unzip -qq data.zip -d data/
!echo "Number of Images : $(find ./data -type f | wc -l)"

필요한 라이브러리들을 준비합니다.

# Import necessary libraries
from transformers import ViTForImageClassification, ViTFeatureExtractor
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
import torch
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split

# Load the dataset
dataset = load_dataset("./data")

실험의 재현을 위한 준비를 합니다.

import os
import random

# Function to set the seed for reproducibility
def set_seed(seed_value=42):
"""Set seed for reproducibility."""
np.random.seed(seed_value)
torch.manual_seed(seed_value)
torch.cuda.manual_seed(seed_value)
torch.cuda.manual_seed_all(seed_value) # if you are using multi-GPU.
random.seed(seed_value)
os.environ['PYTHONHASHSEED'] = str(seed_value)

# The below two lines are for deterministic algorithm behavior in CUDA
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Set the seed
set_seed()

이미지 전처리

먼저 데이터를 다운로드 받고 몇 가지 전처리를 합니다.

# Initialize the feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

# Preprocess the dataset
def preprocess_images(examples):
images = [feature_extractor(image.convert("RGB")) for image in examples["image"]]
examples['pixel_values'] = [image['pixel_values'][0] for image in images]
return examples

dataset = dataset.map(preprocess_images, batched=True)
dataset.set_format(type='torch', columns=['image', 'pixel_values', 'label'])

# Split the dataset into training and validation sets using Hugging Face's built-in method
train_test_split = dataset["train"].train_test_split(test_size=0.2)

# Assign the split datasets
train_dataset = train_test_split["train"]
val_dataset = train_test_split["test"]

위 코드에선 크게 아래와 같이 3가지 일을 수행합니다.

  • 피처 추출기 초기화 : ViTFeatureExtractor.from_pretrained 메소드를 사용하여 사전 훈련된 'google/vit-base-patch16-224-in21k' 모델을 기반으로 피처 추출기를 초기화합니다. 이 피처 추출기는 이미지의 크기 조정, 정규화 및 필요한 형식 변환을 자동으로 처리합니다. 자연어처리의 tokenizer처럼 Image Model 에 특화된 전처리 부분을 담당해서 처리해주기 때문에 보통 모델마다 고유한 Feature Extractor를 다운로드 받을 수 있습니다.
  • 데이터셋 사전 처리:preprocess_images 함수를 정의하여 각 이미지를 RGB로 변환하고, feature_extractor를 통해 이미지를 전처리합니다. 전처리된 이미지는 픽셀 값과 레이블 정보를 포함하도록 데이터셋에 추가됩니다. dataset.map 함수를 사용하여 전체 데이터셋에 이 사전 처리를 적용하고, torch 타입으로 데이터 형식을 설정합니다.
  • 훈련 및 검증 세트 분할: 데이터셋을 훈련 세트와 검증 세트로 분할합니다. 여기서는 훈련 데이터의 20%를 검증 데이터로 사용합니다. Hugging Face의 내장 메서드 train_test_split을 사용하여 이 분할을 수행합니다.

이렇게 전처리한 데이터를 살펴보면 아래와 같이 나타납니다.

dataset
DatasetDict({
train: Dataset({
features: ['image', 'label', 'pixel_values'],
num_rows: 1600
})
test: Dataset({
features: ['image', 'label', 'pixel_values'],
num_rows: 400
})
})

Model 준비

# Ensure you're accessing the 'train' split (or another specific split) to get the features
num_labels = len(dataset['train'].features['label'].names)

# Load pre-trained ViT model
model = ViTForImageClassification\
.from_pretrained('google/vit-base-patch16-224-in21k',
num_labels=num_labels # <--!!!
)

Vision Transformers에서의 Head 기반 분류

Vision Transformers에서 ‘head 기반 분류’는 특정 분류 작업에 맞춤화된 ‘헤드’를 사전 훈련된 트랜스포머 모델에 추가하는 것을 말합니다. ViT 모델은 NLP처럼 입력 데이터(이 경우, 이미지 패치)를 여러 트랜스포머 레이어를 통해 처리합니다. 이 레이어의 출력은 그 후 분류 헤드로 전달됩니다.

분류 헤드는 일반적으로 완전 연결 신경망 레이어(보통 MLP)로, 트랜스포머에 의해 추출된 고차원 특성을 데이터셋의 클래스 수에 매핑합니다. 이미지 분류 작업의 경우, 이 헤드는 트랜스포머의 출력 중 [CLS] 토큰(전체 시퀀스에서 정보를 집약하는 데 사용되는 특수 토큰)에 해당하는 출력을 받아 클래스별 확률 분포를 생성합니다.

여기서 ‘num_labels’ 매개변수는 매우 중요합니다. 이 매개변수는 ViT 모델의 분류 헤드를 우리 과일 데이터셋의 클래스 수에 맞게 조정합니다. 이를 통해 모델의 출력 레이어가 데이터셋의 특정 클래스에 대한 확률 분포를 생성하도록 올바르게 크기를 조정하게 됩니다.

Model 훈련

훈련 인자 정의

# Define training arguments
training_args = TrainingArguments(
output_dir="./vit_fruit_classification",
evaluation_strategy="epoch",
learning_rate=2e-4,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=3,
weight_decay=0.01,

logging_dir='./logs',
)

# Function to compute metrics
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='weighted')
acc = accuracy_score(labels, predictions)
return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

# Initialize Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
compute_metrics=compute_metrics,
)

훈련

trainer.train()

우리가 다루고 있는 과일데이터셋은 쉬운 task 이기 때문에, 1 epoch 시에 이미 모두 맞추고 있음을 알 수 있습니다.

Model 저장

# Save the fine-tuned model
model.save_pretrained("./vit_fruit_cls")
# Save the feature extractor
feature_extractor.save_pretrained('./vit_fruit_cls')

향후 평가 및 예측에 사용할 수 있게 저장해둡니다. BERT 때는 Tokenizer를 저장하게 되는데, ViT에서는 feature extractor를 저장하는 것에 주목하세요.

예측 및 평가

이제 테스트셋에 대해 평가를 진행해 보겠습니다. 이를 위해 pipeline을 만들어 두고, 이를 이용해 쉽게 데이터셋 전체에 대해 평가를 진행할 수 있습니다.

from transformers import pipeline

# Load the pipeline with the model and feature extractor
image_classifier = pipeline('image-classification',
model='./vit_fruit_cls',
feature_extractor='./vit_fruit_cls')
from PIL import Image
import torchvision.transforms.functional as TF
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm

true_labels = []
pred_labels = []
scores = [] # List to store probability values

for item in tqdm(dataset['test']):
# Convert the PyTorch tensor to a PIL Image
t_img = item['image'].permute(2, 0, 1)
image = TF.to_pil_image(t_img)

# Use the pipeline to predict the class of each image
pred = image_classifier(image)
pred_label = pred[0]['label']

# Extract score for the predicted label
score = pred[0]['score']

# Convert predicted label to the corresponding index
pred_label_idx = int(pred_label.split('_')[-1])

true_labels.append(item['label'])
pred_labels.append(pred_label_idx)
scores.append(score) # Append score to the list

# Convert lists to numpy arrays for metric calculation
true_labels = np.array(true_labels)
pred_labels = np.array(pred_labels)

# Calculate accuracy, precision, recall, and F1-score
accuracy = accuracy_score(true_labels, pred_labels)
precision = precision_score(true_labels, pred_labels, average='macro') # Change 'macro' to 'micro' or 'weighted' as needed
recall = recall_score(true_labels, pred_labels, average='macro') # Change 'macro' to 'micro' or 'weighted' as needed
f1 = f1_score(true_labels, pred_labels, average='macro') # Change 'macro' to 'micro' or 'weighted' as needed

# Print the metrics
print(f'Accuracy: {accuracy:.2f}')
print(f'Precision: {precision:.2f}')
print(f'Recall: {recall:.2f}')
print(f'F1-score: {f1:.2f}')
Accuracy: 1.00
Precision: 1.00
Recall: 1.00
F1-score: 1.00

위 코드에서 보면 이미지 파일을 PIL 형태로 바꿔주는 것과 예측결과를 해석해주는 파트 말고는 별 다를게 없습니다. 그외의 작업들은 pipeline에서 모두 처리하기 때문이죠.

Confusion Matrix

import seaborn as sns
from sklearn.metrics import confusion_matrix

# Access the classes attribute to obtain class labels
classes = dataset['train'].features['label'].names

# Calculate the confusion matrix
conf_matrix = confusion_matrix(true_labels, pred_labels)

# Create a heatmap to visualize the confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', cbar=False,
xticklabels=range(len(classes)), yticklabels=range(len(classes)))

plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.xticks(ticks=np.arange(len(classes)) + 0.5, labels=classes)
plt.yticks(ticks=np.arange(len(classes)) + 0.5, labels=classes)
plt.show()

confusion matrix를 통해 살펴봐도, 모든 과일들에 대해 라벨이 잘 예측되고 있음을 알 수 있습니다.

예측 결과 시각화

우리 모델은 과일데이터 셋에 대해서 모두 잘 맞추기 때문에 추가 분석이 필요 없습니다만, 보통 어떤 이미지에 대해 어떻게 예측하는지를 아래처럼 시각화 해보면 추가적인 Insight를 얻을 수 있습니다.

import matplotlib.pyplot as plt
import random

# Define a function to display predictions with class names
def display_predictions(dataset, true_labels, pred_labels, class_names, scores, num_samples=10):
fig, axs = plt.subplots(num_samples, 4, figsize=(12, 24))

for i in range(num_samples):
# Randomly select a sample from the dataset
index = random.randint(0, len(dataset) - 1)
image = dataset[index]['image']
true_label = true_labels[index]
pred_label = pred_labels[index]
class_name = class_names[pred_label] # Get class name from class_names list
score = scores[index]

# Display the image
axs[i, 0].imshow(image) # Convert to HWC format
axs[i, 0].axis('off')

# Display the predicted label with class name
axs[i, 1].text(0.5, 0.5, f'Predicted: {class_name}', fontsize=12, ha='center')
axs[i, 1].axis('off')

# Display the true label with class name
axs[i, 2].text(0.5, 0.5, f'True: {class_names[true_label]}', fontsize=12, ha='center')
axs[i, 2].axis('off')

# Display the score
axs[i, 3].text(0.5, 0.5, f'Score: {score:.2f}', fontsize=12, ha='center')
axs[i, 3].axis('off')

plt.tight_layout()
plt.show()

# Assuming you have true_labels, pred_labels, probabilities, and class_names
display_predictions(dataset['test'], true_labels, pred_labels, classes, scores)

결론

이 글을 통해 우리는 Vision Transformer (ViT)를 활용한 head-based 이미지 분류 기술을 살펴보고, 이를 시각적 객체 인식의 특정 문제인 과일 이미지 분류에 적용해 보았습니다. 특히, ViT 모델에서 이미지를 여러 패치로 나누어 각각을 독립적인 토큰으로 처리하고, 전체 이미지 정보를 대표하는 클래스 토큰을 사용하여 분류 헤드를 부착하는 방법을 배웠습니다. 이 기술을 활용하여 다양한 과일 이미지를 포함하는 데이터셋을 대상으로 모델을 미세 조정하고 평가하였으며, 결과를 시각화하여 모델의 성능을 분석해 보았습니다.

본 실습 코드는 Colab에서 직접 다운로드 받거나 실행해볼 수 있습니다.

--

--

Hugman Sangkeun Jung

Hugman Sangkeun Jung is a professor at Chungnam National University, with expertise in AI, machine learning, NLP, and medical decision support.