Unsupervised Data Augmentation for Consistency Training

Damian
mojitok
Published in
11 min readNov 14, 2019

최근 딥러닝의 성공 요인으로는 알고리즘의 발전, GPU/TPU 등 하드웨어의 발전, 그리고 ImageNet을 비롯한 레이블링 된 대량의 데이터를 꼽을 수 있습니다. 하지만 대량의 레이블 데이터를 구하는 건 매우 어렵습니다. 연구자들은 이러한 문제를 해결하기 위해 준지도학습 (Semi-supervised Learning, SSL)이나 supervised data augmentation 등의 방법을 연구해왔습니다. 본 포스트에서는 두 방법의 장점을 합쳐 매우 뛰어난 성과를 거둔 논문 “Unsupervised Data Augmentation for Consistency Training”을 소개합니다.

Note: 아래의 코드는 구글리서치의 공식 구현을 조금 수정한 것입니다.

Background

준지도학습의 방법 중 하나인 consistency/smoothness enforcing은 모델의 예측이 입력값의 작은 변화에 민감하게 반응하지 않도록 강제하는 방법으로, 일반적으로 입력값에 Gaussian noise나 Drop-out을 적용한 후 두 입력값에 대한 모델의 아웃풋이 유사하도록 학습합니다.

Data augmentation이란 일반적으로 지도학습 (Supervised Learning)에서 사용되는 방법으로, 레이블이 존재하는 데이터에 변화를 줘 원본 데이터와 같은 레이블을 가지는 새로운 데이터를 만드는 방법입니다. Data augmentation의 예시로는 이미지 데이터의 경우 이미지 회전, 확대 등이 있고, 텍스트 데이터의 경우 paraphrasing 등이 있습니다.

Unsupervised Data Augmentation

Unsupervised Data Augmentation for Consistency Training에서는 위의 두 방법을 결합합니다. 일반적인 consistency enforcing 방법은 레이블이 없는 입력값에 랜덤 노이즈를 적용하여 새로운 입력값을 만들지만, 본 논문에서는 data augmentation 방법을 사용하여 새로운 입력값을 만들어냅니다.

본 논문에서는 이미지 분류 문제와 텍스트 분류 문제에 각각 적용할 수 있는 data augmentation 방법 세 가지를 소개합니다.

  • AutoAugmentation (이미지 분류): 강화학습을 활용하여 최적의 이미지 augmentation 정책을 찾는 방법입니다. 자세한 설명은 논문을 참고해주세요.
  • Back translation (텍스트 분류): 두개의 기계번역 모델을 활용하여 원본 텍스트와 의미적으로 유사한 다양한 텍스트를 얻는 방법입니다. 논문의 저자는 영어-프랑스어와 프랑스어-영어 번역 모델을 사용하였습니다.
  • TF-IDF 기반 단어 대체 (텍스트 분류): 위키피디아 문서의 카테고리를 예측해야 하는 DBPedia 문제의 경우 키워드 단어를 보존하는 것이 매우 중요합니다. 따라서 문서의 TF-IDF 벡터에서 값이 낮은, 즉 설명력이 낮은 단어를 다른 단어로 대체하는 방식을 선택하였습니다.

Objective

논문에서 소개하는 방법을 그림으로 표현하면 다음과 같습니다.

목적함수는 다음과 같이 정의됩니다.

overall objective of UDA
‘unsupervised’ objective of UDA

즉, 레이블이 있는 데이터의 cross-entropy loss와, 레이블이 없는 입력값과 data augmentation을 적용하여 얻은 새로운 입력값의 두 아웃풋 사이의 KL divergence, 이 둘의 합을 최소화하는 것이 본 논문에서 제시하는 objective입니다. (논문에서는 위 목적함수의 앞부분을 supervised objective, 뒷부분을 unsupervised objective로 지칭합니다.)

하지만 여러 문제로 인해 이 방법을 곧바로 적용하기는 어렵습니다. 우선, 레이블이 있는 데이터의 크기가 상대적으로 많이 작기 때문에 모델이 빠르게 과적합되는 문제가 발생할 수 있습니다. 또, unlabeled 데이터에 대한 모델의 예측이 평평하여 (over-flat) 학습이 잘 되지 않을 수 있다는 문제가 있습니다. 마지막으로 unlabeled 데이터와 labeled 데이터의 분포가 많이 다를 경우, unlabeled 데이터를 추가하는 것이 오히려 모델의 성능을 떨어트릴 수 있다는 문제가 있습니다.

이러한 문제를 해결하기 위해 논문에서는 다양한 테크닉을 소개합니다.

Additional Training Techniques

Training Signal Annealing (TSA)

대규모의 unlabeled 데이터를 잘 활용하기 위해서는 모델의 크기 역시 커야하지만, 모델의 크기가 클수록 소량의 labeled 데이터에 과적합되기도 쉽습니다. 이러한 문제를 해결하기 위해 논문에서는 Training Signal Annealing (TSA)라는 방법을 소개합니다.

TSA는 과적합을 막기 위해 학습 초반에는 정답 레이블에 대한 confidence가 높은 labeled 데이터를 학습에 이용하지 않다가, 학습이 진행되면서 confidence가 높은 labeled 데이터도 점진적으로 학습에 사용하는 방법입니다. TSA는 목적함수의 supervised objective를 다음으로 바꾸는 방법으로 구현됩니다.

supervised objective with TSA

η는 threshold로, 1/K와 1 사이의 값을 가집니다. (K: 카테고리의 개수). η는 1/K에서 시작하여 점진적으로 1까지 커지게 되는데, 논문에서는 η를 증가시키는 방법으로 세가지 schedule (log-schedule, linear-schedule, exp-schedule)을 소개합니다.

comparison of three threshold scheduling methods

TSA를 적용한 supervised objective의 공식 텐서플로우 구현은 다음과 같습니다.

Official implementation of UDA supervised objective + comments

Sharpening Predictions

위에서 언급한 문제 중 unlabeled 데이터에 대한 모델의 예측이 평평하여 (over-flat) 학습이 잘 되지 않을 수 있다는 문제는 어떻게 해결할 수 있을까요? 논문에서는 세가지 방법을 소개합니다.

  • Confidence-based masking: 예측의 confidence가 낮은 unlabeled 데이터를 학습에 이용하지 않습니다.
  • Entropy minimization: augmented 데이터의 예측값이 낮은 entropy를 가지도록 (=예측이 더 sharp하도록) entropy objective term을 전체 objective에 추가합니다.
  • Softmax temperature controlling: unlabeled 데이터의 예측값을 계산할 때 1 미만의 softmax temperature를 적용하여 augmented 데이터의 타겟이 더 sharp해지도록 합니다.

논문 저자에 따르면 labeled 데이터가 매우 적을 경우 confidence-based masking과 softmax temperature controlling이 유용하고, 상대적으로 labeled 데이터가 많은 경우 entropy minimization이 효과가 있다고 합니다.

공식 텐서플로우 구현은 다음과 같습니다. (공식 구현에는 entropy minimization이 생략되어 있습니다.)

Official implementation of UDA unsupervised objective + comments

Domain-relevance Data Filtering

unlabeled, out-of-domain 데이터는 labeled 데이터보다 더 구하기 쉽지만, 두 데이터의 클래스 분포가 크게 다를 경우 오히려 학습을 방해할 수 있습니다. 이런 문제를 완화하기 위해 논문의 저자들은 우선 모델을 labeled 데이터로 학습한 후, unlabeled 데이터 중 confidence가 높은 데이터만을 사용하는 필터링 기법을 활용하였습니다.

Experiments

텍스트 분류 위주로 실험 결과를 살펴보겠습니다.

논문에서는 4가지 모델 구조로 실험을 진행했습니다. 4가지 모델은 각각 1) 무작위로 초기화 된 Transformer 구조의 모델 (Random), 2) BERT-Base, 3) BERT-Large, 4) BERT-Large를 unlabeled 데이터로 추가 학습한 BERT-Finetune이 되겠습니다. (Transformer에 대해서는 저희의 이전 포스트를, BERT에 대해서는 논문코드를 참고해주세요.)

실험의 결과는 아래와 같습니다.

UDA experiments result — text classification
  • UDA 방법을 활용하면 어떤 모델이든 성능이 향상됨을 알 수 있습니다.
  • UDA 방법을 활용하면 훨씬 적은 labeled 데이터만을 가지고도 기존의 결과에 근접한 성능을 얻을 수 있었습니다. 특히 IMDb 데이터셋과 Yelp-2 데이터셋의 경우, 각각 전체 labeled 데이터의 1250분의 1, 28000분의 1만을 사용하였지만 오히려 성능은 Pre-BERT SOTA (state-of-the-art) 보다 더 좋아졌습니다. (BERT-Finetune 기준)
  • 5개의 카테고리로 분류해야 하는 Yelp-5와 Amazon-5 데이터셋에서는 UDA 방법이 모델의 성능을 향상시키기는 하지만, fully supervised baseline과의 격차가 크게 나타남을 확인할 수 있습니다.

이에 더해 논문에서는 다양한 labeled 데이터셋 크기에서 일반적인 지도학습 방법과 UDA를 사용하는 준지도학습 방법을 비교해 보았는데, 그 결과는 다음과 같았습니다.

UDA vs supervised with various labeled data size
  • labeled 데이터셋의 크기가 커질수록 두 방법의 성능 격차가 줄어들지만, 여전히 UDA를 사용하는 준지도학습 방법이 더 좋은 성능을 보였습니다.

이미지 분류 문제의 경우 실험 결과는 생략하겠습니다. 자세한 내용은 논문을 참고해주세요.

Ablation Study

마지막으로 위에서 소개했던 TSA (Training Signal Annealing)의 효과에 대한 ablation study를 보겠습니다. TSA의 효과를 살펴보기 위해 Yelp-5 데이터셋과 CIFAR-10 데이터셋에서 세가지 TSA schedule을 테스트했는데, 결과는 다음과 같습니다.

Yelp-5 데이터셋에서는 어떤 schedule을 사용하든 성능이 좋아지는 것을 확인할 수 있고, CIFAR-10 데이터셋에서는 exp-schedule을 제외한 두 schedule에서 성능이 좋아졌습니다.

Yelp-5 데이터셋의 경우 labeled 데이터에 비해 unlabeled 데이터가 훨씬 많기 때문에 exp-schedule의 성능이 가장 좋은 것으로 추측되며, 반대로 CIFAR-10 데이터셋의 경우 unlabeled 데이터가 상대적으로 적기 때문에 exp-schedule이 오히려 학습을 방해한 것으로 추측됩니다.

마치며

프로덕션에서 딥러닝을 적용할 때 가장 어려운 부분 중 하나가 바로 레이블링 된 데이터의 부족이 아닐까 싶습니다. 이 포스트가 데이터 부족으로 고생하시는 많은 분들께 도움이 되길 바랍니다.

--

--