[Hands-On] CLIP을 활용한 프롬프트 기반 이미지 분류

Hugman Sangkeun Jung
31 min readJul 13, 2024

--

(You can find the English version of the post at this link.)

이전 글에서 우리는 텍스트 분류를 위한 프롬프트 기반 접근법에 대해 살펴보았습니다.

이번 포스트에서는 그 개념을 확장하여 이미지 분류 작업에 적용해 보겠습니다. 특히 OpenAI의 CLIP(Contrastive Language–Image Pre-training) 모델을 활용하여 프롬프트 기반 이미지 분류를 수행하는 실제 코드를 구현하고 그 결과를 분석해 보도록 하겠습니다.

프롬프트 기반 이미지 분류란 무엇인가?

프롬프트 기반 이미지 분류는 CLIP(Contrastive Language-Image Pre-training)과 같은 비전-언어 모델들을 활용하여 이미지를 분류하는 혁신적인 방법입니다. 이러한 모델들은 대규모의 이미지-텍스트 쌍 데이터를 학습하여 시각적 정보와 언어적 설명 사이의 깊은 연관성을 이해할 수 있게 됩니다.

CLIP을 시작으로, ALIGN(Aligning text and images), DALL-E, Imagen, Stable Diffusion 등의 모델들이 등장하면서 이미지와 텍스트의 통합적 이해 능력이 크게 향상되었습니다. 이러한 모델들은 단순히 이미지를 분류하는 것을 넘어서, 이미지의 내용을 자연어로 설명하거나 반대로 텍스트 설명을 바탕으로 이미지를 생성하는 등의 다양한 작업을 수행할 수 있습니다.

프롬프트 기반 이미지 분류의 주요 특징은 다음과 같습니다:

  1. 텍스트 프롬프트 활용: 이미지 클래스를 설명하는 자연어 문장을 프롬프트로 사용합니다. 이는 단순한 레이블을 넘어서 더 풍부하고 상세한 정보를 제공할 수 있습니다.
  2. 이미지-텍스트 alignment: 모델은 입력된 이미지와 다양한 텍스트 프롬프트 간의 유사도를 계산합니다. 이 과정에서 이미지의 시각적 특징과 텍스트의 의미론적 내용이 고차원 공간에서 매칭됩니다.
  3. Zero-shot 및 Few-shot 학습: 특정 태스크에 대한 추가 학습 없이도 다양한 분류 작업을 수행할 수 있습니다. 또한, 소수의 예제만으로도 새로운 개념을 빠르게 학습할 수 있는 능력을 가지고 있습니다.
  4. 유연성과 확장성: 새로운 클래스를 추가하거나 변경할 때 모델을 재학습할 필요 없이 프롬프트만 수정하면 됩니다. 이는 실제 응용에서 매우 유용한 특징입니다.
  5. 멀티모달 이해: 이미지와 텍스트를 동시에 처리할 수 있어, 복잡한 개념이나 추상적인 아이디어도 표현하고 인식할 수 있습니다.
  6. 컨텍스트 인식: 단순히 객체를 인식하는 것을 넘어서, 이미지의 전체적인 맥락과 상황을 이해할 수 있습니다.
  7. 크로스 모달 전이 학습: 한 모달리티(예: 텍스트)에서 학습한 지식을 다른 모달리티(예: 이미지)에 적용할 수 있습니다.

이러한 접근 방식은 전통적인 지도 학습 기반의 이미지 분류 방법과는 달리, 언어의 풍부한 표현력을 활용하여 이미지를 더 유연하고 포괄적으로 이해할 수 있게 해줍니다. 또한, 계속해서 발전하는 대규모 언어 모델(LLM)과의 통합을 통해 더욱 강력한 이미지 이해 및 생성 능력을 보여주고 있습니다.

이제 이러한 기본 개념을 바탕으로 CLIP을 활용한 프롬프트 기반 이미지 분류를 구현해보겠습니다.

환경 준비하기

먼저, 필요한 라이브러리를 설치하고 가져옵니다.

!pip install -qq datasets

# Import necessary libraries
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch

# Check if CUDA is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

여기서 우리는 Hugging Face의 transformers 라이브러리를 사용하여 CLIP 모델과 프로세서를 불러옵니다. PIL(Python Imaging Library)은 이미지 처리를 위해 사용되며, PyTorch는 딥러닝 연산을 위해 필요합니다. CUDA가 사용 가능한 경우 GPU를 활용하여 연산 속도를 높일 수 있습니다.

데이터 준비

이번 실습에서는 과일 이미지 데이터셋을 사용합니다. 데이터셋을 다운로드하고 로드하는 과정은 다음과 같습니다:

!rm data.zip
!rm -r data
!wget -qq 'https://www.dropbox.com/scl/fi/mf806inzv0x6abbb2su0k/fruit-dataset.zip?rlkey=gk7s9d14m6k9o1ru4whrgtwke&dl=1' -O data.zip
!unzip -qq data.zip -d fruit-dataset/
!echo "Number of Images : $(find ./fruit-dataset -type f | wc -l)"
Number of Images : 2000
# Load the fruit dataset
from datasets import load_dataset
dataset = load_dataset("./fruit-dataset")

# Preview the dataset structure
print(dataset)

class_map = { k:v for k, v in enumerate(dataset['train'].features['label'].names) }
class_names = [v for k,v in class_map.items()]
print(class_map)
print(class_names)
DatasetDict({
train: Dataset({
features: ['image', 'label'],
num_rows: 1600
})
test: Dataset({
features: ['image', 'label'],
num_rows: 400
})
})
{0: 'apple', 1: 'asian pear', 2: 'banana', 3: 'cherry'}
['apple', 'asian pear', 'banana', 'cherry']

이 과정을 통해 우리는 과일 이미지 데이터셋을 다운로드하고, Hugging Face의 datasets 라이브러리를 사용하여 로드합니다. 데이터셋의 구조를 확인하고, 클래스 이름과 인덱스 간의 매핑을 생성합니다. 이는 후속 작업에서 예측 결과를 해석할 때 유용하게 사용됩니다.

CLIP 모델 및 프로세서 준비

다음으로 CLIP 모델과 프로세서를 로드합니다:

# Load CLIP model and processor
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

여기서 우리는 OpenAI에서 제공하는 사전 학습된 CLIP 모델을 사용합니다. ‘clip-vit-base-patch32’는 Vision Transformer(ViT) 아키텍처를 기반으로 하는 CLIP 모델의 한 변형입니다.

참고로, CLIP 말고도 BLIP과 같은 CLIP의 후속 모델을 사용할 수도 있습니다. 간단히 정리하면, CLIP은 대규모 이미지-텍스트 쌍 데이터로 학습된 모델로, 제로샷 학습 능력이 뛰어나며 다양한 시각적 작업에서 범용성을 보입니다. BLIP은 CLIP의 개선 버전으로, 더 정교한 이미지-텍스트 관계 이해와 생성 능력을 갖추고 있으며, 특히 이미지 캡셔닝과 시각적 질의응답 작업에서 우수한 성능을 보입니다.

BLIP을 활용하려면 아래 코드처럼 수행하면 됩니다:

from transformers import BlipProcessor, BlipForConditionalGeneration

# Load BLIP model and processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

BLIP 모델은 이미지 캡셔닝, 시각적 질의응답, 이미지-텍스트 검색 등 다양한 작업을 수행할 수 있습니다. 위의 예시는 이미지 캡셔닝을 위한 BLIP 모델을 로드하는 방법을 보여줍니다. BLIP은 CLIP보다 더 복잡한 이미지-텍스트 관계를 이해하고 생성할 수 있어, 더 상세하고 정확한 이미지 설명을 제공할 수 있습니다.

두 모델 모두 강력한 성능을 보이지만, 작업의 특성에 따라 적절한 모델을 선택하는 것이 중요합니다. CLIP은 제로샷 분류 작업에 강점이 있고, BLIP은 더 복잡한 이미지-텍스트 관계 이해가 필요한 작업에 적합합니다.

본 실습에서는 CLIP 을 중심으로 소개하겠습니다.

이미지 분류 함수 구현

이제 CLIP을 사용하여 이미지를 분류하는 핵심 함수를 구현해보겠습니다:

import os
def classify_image(image, align_texts):
# Check if the image is a file path
if isinstance(image, str) and os.path.isfile(image):
image = Image.open(image)
elif not isinstance(image, Image.Image):
# If the image is not a PIL Image and not a file path, raise an error
raise ValueError("The provided image must be a PIL Image or a file path to an image.")
# If 'image' is already a PIL Image, no conversion is needed

# Process the image with the processor
inputs = processor(text=align_texts, images=image, return_tensors="pt", padding=True)

# To see the actual text used for alignment, inspect the tokenized text
tokenized_text = inputs['input_ids']

# Decode the tokenized text to see the human-readable format
decoded_text = [processor.tokenizer.decode(ids) for ids in tokenized_text]

# Move inputs to the device
inputs = {k: v.to(device) for k, v in inputs.items()}

# Get model predictions
with torch.no_grad():
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image # Image-to-text similarity scores
probs = logits_per_image.softmax(dim=1)
pred_text_idx = torch.argmax(probs, dim=1)

# Return the class name with the highest probability
return pred_text_idx.item(), decoded_text, probs

이 함수는 주어진 이미지와 텍스트 프롬프트들을 입력으로 받아, 이미지와 가장 잘 일치하는 텍스트 프롬프트의 인덱스와 각 프롬프트에 대한 확률을 반환합니다.

프롬프트 생성 및 분류 수행

먼저 프롬프르 기반의 이미지 분류 코드 구현의 전체 흐름을 설명하겠습니다.

  1. 텍스트 프롬프트 생성: 미리 정의된 query_templates를 사용하여 각 클래스에 대한 설명 텍스트를 생성합니다. 이 텍스트들은 이미지의 잠재적 내용을 설명하는 자연어 프롬프트 역할을 합니다. 예를 들어, "이것은 사과의 이미지입니다" 또는 "이 사진에는 고양이가 있습니다" 같은 문장들이 포함될 수 있습니다.
  2. 이미지 처리 및 정렬:
    -
    classify_image 함수는 입력이 유효한 이미지인지 또는 이미지 파일 경로인지 확인합니다. 경로인 경우 PIL 라이브러리를 사용하여 이미지를 로드합니다.
    - 텍스트 프롬프트와 이미지를 함께 처리하여 모델의 입력 텐서를 생성합니다. 이 과정에서 텍스트를 토큰화(모델이 처리할 수 있는 숫자 형식으로 변환)하고 이미지 데이터를 준비합니다.
  3. 모델 예측:
    - 처리된 입력을 모델에 제공하여 이미지와 각 텍스트 프롬프트 간의 유사도 점수를 계산합니다. 이 단계에서는 텍스트와 이미지 특징 간의 정렬을 평가할 수 있는 신경망 모델을 사용합니다.
    - 이 점수들에 소프트맥스 함수를 적용하여 각 프롬프트가 이미지와 일치할 확률을 도출합니다.
  4. 결과 생성:
    - 시스템은 가장 높은 확률 점수를 가진 프롬프트를 이미지의 가장 정확한 설명으로 식별하여 효과적으로 이미지를 분류합니다.
    - 또한, 이 함수는 사용된 프롬프트의 사람이 읽을 수 있는 버전과 확률 점수를 반환하여 모델의 의사 결정 과정에 대한 통찰을 제공합니다.

이제 위 내용을 구현해보죠.

query_templates = [
"This is an image of a {image}.",
"Here we see a {image}.",
"The photo depicts a {image}.",
"In this picture, there is a {image}.",
"This picture shows a {image}.",
"A {image} is present in this image.",
"You can see a {image} here.",
"This image represents a {image}.",
"A {image} is captured in this shot.",
"The object in this image is a {image}."
]

def get_n_query_texts(class_name, N):
query_texts = [template.format(image=class_name) for template in query_templates[:N]]
return query_texts

def make_query_texts(class_names, N):
class_and_texts_map = []

def prepend_a_or_an(word):
# List of vowels
vowels = 'aeiou'

# Check if the first letter of the word is a vowel
if word[0].lower() in vowels:
return f"an {word}"
else:
return f"a {word}"

for cls in class_names:
query_texts = get_n_query_texts(cls, N)
items = [ (cls, q) for q in query_texts ]
class_and_texts_map += items

return class_and_texts_map

def do_prompt_based_image_classification(image_path, N=1):
class_and_texts_map = make_query_texts(class_names, N)
query_texts = [x[1] for x in class_and_texts_map]
best_aligned_text_idx, actual_query_texts, probs = classify_image(image_path, query_texts)
return best_aligned_text_idx, actual_query_texts, class_and_texts_map, probs.softmax(-1)

이 함수들은 각 클래스에 대해 여러 개의 텍스트 프롬프트를 생성하고, 이를 사용하여 주어진 이미지를 분류합니다. N 파라미터를 조절하여 각 클래스당 사용할 프롬프트의 수를 지정할 수 있습니다.

간단히 샘플 이미지에 대해서 분류를 진행해 보겠습니다.

image_path = "./fruit-dataset/test/apple/apple_401.png"
best_aligned_text_idx, actual_query_texts, class_and_texts_map, probs = do_prompt_based_image_classification(image_path)
best_aligned_text_idx
0
actual_query_texts
['<|startoftext|>this is an image of a apple. <|endoftext|><|endoftext|>',
'<|startoftext|>this is an image of a asian pear. <|endoftext|>',
'<|startoftext|>this is an image of a banana. <|endoftext|><|endoftext|>',
'<|startoftext|>this is an image of a cherry. <|endoftext|><|endoftext|>']
class_and_texts_map
[('apple', 'This is an image of a apple.'),
('asian pear', 'This is an image of a asian pear.'),
('banana', 'This is an image of a banana.'),
('cherry', 'This is an image of a cherry.')]
probs
tensor([[0.4009, 0.1855, 0.1844, 0.2291]])

확률값을 보게 되면 apple 로 분류될 확률이 가장 높음을 알 수 있습니다.

결과 시각화

위의 결과를 실제 이미지와 함께 보기 좋게 출력해주는 함수를 구현해보죠. 아래 함수는 입력 이미지, 예측된 클래스, 그리고 해당 예측의 확률을 함께 표시합니다.

import matplotlib.pyplot as plt

def visualize_classification(image_path, best_align_text_idx, class_and_texts_map, probs):
# Load and display the image
image = Image.open(image_path)
plt.imshow(image)
plt.axis('off') # Hide the axis

probability = probs[0][best_align_text_idx]
class_label, text = class_and_texts_map[best_align_text_idx]

plt.title(f"Best aligned text: {text}\nPredicted Class: {class_label }\nProbability: {probability:.4f}", loc='left')

# Show the plot
plt.show()
image_path = "./fruit-dataset/test/apple/apple_401.png"
best_align_text_idx, actual_query_texts, class_and_texts_map, probs = do_prompt_based_image_classification(image_path)
visualize_classification(image_path, best_align_text_idx, class_and_texts_map, probs)

모델 평가

위에서 샘플 이미지 1장에 대한 작동과정을 설명했습니다. 이제 테스트 데이터셋을 사용하여 모델의 성능을 평가해보겠습니다. 평가의 속도를 위해서 우선 200개 샘플만 사용할 텐데, 더 많은 데이터를 활용하기를 원하면 아래 코드의 num_samples 를 조정하시면 됩니다.

Prompt 기반 분류 기법의 장법은 prompt 를 class 그룹 별로 다양화 함으로써 얼마든지 앙상블과 비슷한 효과를 얻을 수 있다는 것입니다. 우선 하나의 class 마다 하나의 query 만 부여해서 실험을 진행해보죠.

Single Query Template 을 이용한 분류

from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from random import sample

true_labels = []
predicted_labels = []

r_class_map = { v:k for k, v in class_map.items() }

# Generate a list of random indices based on the length of the dataset
num_samples = 200
random_indices = sample(range(len(dataset['test'])), num_samples)

#for item in tqdm(dataset['test']):
for item in tqdm(dataset['test'].select(random_indices)):
image = item['image']
true_label = item['label']
true_labels.append(true_label)

best_aligned_text_idx, actual_query_texts, class_and_texts_map, probs = do_prompt_based_image_classification(image)
pred_class_label = class_and_texts_map[best_aligned_text_idx][0]
predicted_labels.append(r_class_map[pred_class_label])
# Calculate metrics
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
accuracy = accuracy_score(true_labels, predicted_labels)
precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predicted_labels, average='weighted')

print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1-score: {f1:.4f}')

# Generate the confusion matrix
cm = confusion_matrix(true_labels, predicted_labels)

# Plot the confusion matrix
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.show()
Accuracy: 0.9350
Precision: 0.9394
Recall: 0.9350
F1-score: 0.9347
Single query — Confusion Matrix of prompt-based classification results with CLIP

F1 score 가 0.9347이 나오는것을 확인할 수 있습니다. 별도의 fine-tuning없이 과일 이미지 인식에 관해서 약 0.93 정도를 얻을 수 있음을 보여줍니다.

Multiple Query Templates 를 이용한 분류

이제 하나의 클래스 그룹별로 더 다양한 Query templates를 활용하여 일종의 앙상블을 진행해보도록 하겠습니다. 실제 템플릿들은 아래와 같습니다.

class_and_texts_map
[('apple', 'This is an image of a apple.'),
('apple', 'Here we see a apple.'),
('apple', 'The photo depicts a apple.'),
('apple', 'In this picture, there is a apple.'),
('apple', 'This picture shows a apple.'),
('asian pear', 'This is an image of a asian pear.'),
('asian pear', 'Here we see a asian pear.'),
('asian pear', 'The photo depicts a asian pear.'),
('asian pear', 'In this picture, there is a asian pear.'),
('asian pear', 'This picture shows a asian pear.'),
('banana', 'This is an image of a banana.'),
('banana', 'Here we see a banana.'),
('banana', 'The photo depicts a banana.'),
('banana', 'In this picture, there is a banana.'),
('banana', 'This picture shows a banana.'),
('cherry', 'This is an image of a cherry.'),
('cherry', 'Here we see a cherry.'),
('cherry', 'The photo depicts a cherry.'),
('cherry', 'In this picture, there is a cherry.'),
('cherry', 'This picture shows a cherry.')]
def merge_multiple_query_result(class_and_texts_map, probs):
# Initialize a dictionary to hold the sum of probabilities for each class
class_probs = {}

# Iterate over class_and_texts_map and probs to aggregate probabilities
for idx, (class_name, _) in enumerate(class_and_texts_map):
if class_name not in class_probs:
class_probs[class_name] = 0
class_probs[class_name] += probs[0][idx].item() # Convert from tensor to float and add to the class's total probability

# Find the class with the highest aggregated probability
predicted_class = max(class_probs, key=class_probs.get)
return predicted_class, class_probs[predicted_class]

true_labels = []
predicted_labels = []

for item in tqdm(dataset['test'].select(random_indices)):
image = item['image']
true_label = item['label']
true_labels.append(true_label)

best_aligned_text_idx, actual_query_texts, class_and_texts_map, probs = do_prompt_based_image_classification(image, N=5)

pred_class_label, pred_merged_prob = merge_multiple_query_result(class_and_texts_map, probs)
predicted_labels.append(r_class_map[pred_class_label])

위 코드는 다음과 같은 방식으로 작동합니다:

  1. merge_multiple_query_result 함수:
    - 이 함수는 각 클래스에 대한 여러 쿼리의 확률을 합산합니 다.
    - class_and_texts_map에서 각 클래스와 관련된 쿼리들을 순회하면서, 해당 쿼리의 확률을 클래스별로 누적합니다.
    - 최종적으로 가장 높은 누적 확률을 가진 클래스를 예측 결과로 반환합니다.
  2. 분류 과정:
    -
    테스트 데이터셋의 각 이미지에 대해:
    a. do_prompt_based_image_classification 함수를 호출하여 5개의 쿼리 템플릿(N=5)을 사용한 분류를 수행합니다.
    b. 이 함수는 각 쿼리에 대한 확률과 함께 클래스-텍스트 매핑 정보를 반환합니다.
    c. merge_multiple_query_result 함수를 사용하여 여러 쿼리의 결과를 종합하고 최종 예측 클래스를 결정합니다.

이 접근 방식의 주요 장점은 다음과 같습니다:

  1. 다양한 관점 고려: 여러 쿼리 템플릿을 사용함으로써, 이미지의 다양한 측면을 고려할 수 있습니다. 예를 들어, “This is an image of a dog”“A furry animal with four legs is shown” 같은 서로 다른 표현이 동일한 개념을 설명할 수 있습니다.
  2. 견고성 향상: 단일 쿼리에 의존하는 대신 여러 쿼리의 결과를 종합함으로써, 개별 쿼리의 오류나 편향에 덜 민감해집니다.
  3. 정확도 개선: 다양한 쿼리의 결과를 합산하여 최종 예측을 내림으로써, 전반적인 분류 정확도를 향상시킬 수 있습니다.
  4. 유연성: N 값을 조절함으로써 사용할 쿼리 템플릿의 수를 쉽게 변경할 수 있습니다. 이를 통해 성능과 계산 비용 사이의 균형을 조절할 수 있습니다.
  5. 해석 가능성: 각 쿼리의 기여도를 분석함으로써, 모델의 결정 과정을 더 잘 이해할 수 있습니다.

하지만 이 방법에도 주의해야 할 점이 있습니다:

  1. 계산 복잡도: 여러 쿼리를 처리해야 하므로 단일 쿼리 방식보다 계산 시간이 증가합니다.
  2. 쿼리 설계의 중요성: 사용되는 쿼리 템플릿의 품질과 다양성이 최종 성능에 큰 영향을 미칩니다.

이제 실제 성능을 보죠.

# Calculate metrics
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
accuracy = accuracy_score(true_labels, predicted_labels)
precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predicted_labels, average='weighted')

print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1-score: {f1:.4f}')

# Generate the confusion matrix
cm = confusion_matrix(true_labels, predicted_labels)

# Plot the confusion matrix
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.show()
Accuracy: 0.9400 
Precision: 0.9436
Recall: 0.9400
F1-score: 0.9398
Multiple queries — Confusion Matrix of prompt-based classification results with CLIP

200개 샘플에 대해서 수행해 보았을 때 미세하게 성능이 올라가는 것을 확인할 수 있습니다(F1 score : 0.9347 → 0.9398). 만약 더 많은 템플릿을 잘 설계해서 적용하고 또 데이터 샘플을 다양화하거나 과일 이미지보다 더 복잡하고 난해한 이미지에 대한 분류를 수행한다면 더 큰 성능향상 폭을 경험하실 수 있을 것입니다.

결과 분석 및 개선 방안

실험 결과, 우리의 프롬프트 기반 이미지 분류 모델은 상당히 좋은 성능을 보여주었습니다. 하지만 몇 가지 개선의 여지가 있습니다:

  1. 프롬프트 다양화: 각 클래스에 대해 더 다양한 프롬프트를 사용하면 성능이 향상될 수 있습니다.
  2. 앙상블 기법 적용: 여러 프롬프트의 결과를 종합하여 최종 예측을 수행하는 방식을 고려해볼 수 있습니다.
  3. 모델 fine-tuning: 특정 도메인의 이미지에 대해 CLIP 모델을 fine-tuning하면 더 좋은 성능을 얻을 수 있습니다.
  4. 데이터 증강: 훈련 데이터에 다양한 증강 기법을 적용하여 모델의 일반화 능력을 향상시킬 수 있습니다.

결론

이번 실습을 통해 우리는 CLIP 모델을 활용한 프롬프트 기반 이미지 분류 방법을 구현하고 평가해보았습니다. 이 방법의 주요 장점은 다음과 같습니다:

  1. 높은 성능: 별도의 fine-tuning 없이도 과일 이미지 분류 태스크에서 0.93 이상의 F1 점수를 달성했습니다. 이는 프롬프트 기반 접근법의 효과성을 잘 보여줍니다.
  2. 유연성: 새로운 클래스를 쉽게 추가할 수 있으며, 프롬프트 수정만으로도 다양한 분류 작업에 적용할 수 있습니다.
  3. 제로샷 학습: 특정 태스크에 대한 추가 학습 없이도 다양한 이미지 분류 작업을 수행할 수 있었습니다.
  4. 앙상블 효과: 다중 쿼리 템플릿을 사용함으로써 단일 쿼리 방식보다 더 나은 성능(F1 점수 0.9347에서 0.9398로 향상)을 얻을 수 있었습니다.
  5. 해석 가능성: 모델의 예측 과정을 쉽게 이해하고 분석할 수 있어, 결과에 대한 투명성을 제공합니다.

하지만 이 접근법에도 몇 가지 한계와 개선 가능성이 있습니다:

  1. 계산 복잡도: 여러 쿼리를 처리해야 하므로 단일 쿼리 방식보다 계산 시간이 증가합니다.
  2. 프롬프트 설계의 중요성: 사용되는 쿼리 템플릿의 품질과 다양성이 최종 성능에 큰 영향을 미치므로, 효과적인 프롬프트 설계가 중요합니다. 결국 개발자의 노력이 들어가긴 하는 거죠.
  3. 도메인 특화 성능: 특정 도메인에 대해서는 fine-tuning을 통해 더 높은 성능을 얻을 수 있을 것입니다.

결론적으로, CLIP을 활용한 프롬프트 기반 이미지 분류 방법은 높은 성능과 유연성을 제공하며, 특히 제로샷 학습 능력을 통해 다양한 실제 응용 분야에서 큰 잠재력을 보여줍니다. 이는 전통적인 지도 학습 기반의 이미지 분류 방법과 비교하여, 더 적은 데이터와 학습 시간으로도 효과적인 결과를 얻을 수 있는 완전히 새로운 형태의 접근법입니다.

이번 글에서 구현한 코드는 아래 Colab 링크에서 직접 실행해 볼 수 있습니다.

--

--

Hugman Sangkeun Jung

Hugman Sangkeun Jung is a professor at Chungnam National University, with expertise in AI, machine learning, NLP, and medical decision support.