Dataset Distillation- 데이터를 증류해서 압축하기

scalalang2
CURG
Published in
7 min readOct 9, 2020

--

빅데이터(Big Data)라는 용어는 사회적인 유행어처럼 쓰이고 있다. 몇몇 이들은 많은 데이터를 쌓이 두기만 하면 이를 활용해 새로운 비즈니스를 창출 하고 부를 이룰 수 있다고 믿고 있다. 하지만, 단순히 데이터가 많은 것보다 중요한 건 좋은 데이터이다. 몇 가지 예시를 보자.

2016년, 마이크로소프트는 채팅 봇 테이를 선보이고 16시간 만에 운영을 중단했다. 일부 사람들이 테이에게 인종차별, 혐오표현, 성차별등의 악의적인 발언을 사용하도록 훈련시켜 테이가 자극적인 표현을 남발했기 때문이다. 이 처럼 악의적으로 나쁜 학습 데이터를 주입해서 딥러닝 모델을 의도한 바와 다르게 행동하게 하는 것을 중독 공격(Poisoning Attack)이라고 부른다.

2018년, 세계 최대 컴퓨터 비전 학회인 CVPR에서는 자율주행 자동차가 표지판 분류를 위해 사용하는 알고리즘의 취약점을 이용해서 ‘정지' 표지판을 보고도 그냥 지나치게 하는 공격에 성공한 논문이 제출되었다[1]. 이 공격 방법은 단순히 정지 표지판에 몇 개의 스티커를 부착하는 것만으로도 성공했는데 사람의 시선에서는 오해할 가능성이 현저히 적지만 해당 딥러닝 모델은 이 표지판을 ‘속도 제한'으로 인식하였다.

[그림 1] 정지 표지판에 스티커를 부착하는 것 만으로도 딥러닝 모델은 이를 오해한다.

이제 우리는 빅데이터도 중요하지만 잘 정제된 좋은 데이터를 얻는 것도 중요한 영역임을 알았다. 좋은 모델을 얻기 위해서는 많은 데이터를 이용해 학습하는 것도 좋지만 잘 압축된 적은 데이터로도 비슷한 성능의 모델을 낼 수 있지 않을까? 이번 글에서는 데이터 셋을 증류해서 데이터를 압축하는 방법인 Dataset Distillation를 소개한다. 이 방법은 Facebook AI Research팀의 Tongzhou Wang와 CycleGAN을 고안한 Jun-Yan Zhu가 연구하였다[2].

지식 증류 (Knowledge Distillation)

지식 증류는 Hinton교수가 제안한 방법으로 엄청나게 큰 앙상블 모델의 지식을 더 작은 모델(compact network)로 전달하는 방법이다. Dataset Distillation또한 대량의 데이터를 소수의 데이터로 압축할 수 있지 않을까? 하는 아이디어에서 출발한다. 아래 [그림 2]는 대략적으로 이 논문에서 하고자 한 일을 보여준다.

[그림 2] Dataset Distillation의 효과

[그림 2]에서 MNIST는 10개의 클래스를 가진 0~9 숫자가 적힌 이미지 데이터이다. 이를 각 클래스 별로 1개씩 압축해서 모델을 학습한 결과 테스트 단계에서 94%의 정확도를 보인다. 이 처럼 중복된 데이터를 제거하고 특이한 데이터의 특징을 이미지에 압축하고 이를 학습에 이용하는 방법이 Dataset Distillation이다. 물론 CIFAR10 학습 데이터의 경우 압축 이후 테스트 단계에서 정확도는 54% 정도로 아직 추가적인 연구가 필요하다. 다음 섹션에서는 이 방법이 어떻게 동작하는지 설명한다.

데이터 셋 증류 (Dataset Distillation)

딥러닝의 학습은 입력값인 인풋이 깊은 네트워크를 통해 최종 레이어로 도달했을 때 결과값을 편미분 해가면서 손실을 뒤로 전달하는 오차역전파 원리에 의해 동작한다. 이를 수식화해서 표현하면 아래 [그림 3]과 같다.

[그림 3] 오차역전파를 통해 파라미터인 θ값을 변경한다.

위 수식에서 theta(θ)값은모델의 파라미터, x는 데이터, mu는 학습율, L은 손실 함수를 나타낸다. 데이터를 증류하는 목적은 적은 수의 데이터만으로도 전체 데이터를 학습한것과 동일한 효과를 내기위해서이다. 그래서 이 논문에서는 데이터 증류 문제를 아래 [그림 4]와 같이 정의한다.

[그림 4] Data Distillation에서 하고자 하는 목표식

결국 전체 데이터 집합인 x의 loss와 데이터를 증류해서 나온 _x를 이용했을 때의 loss가 서로 최대한 낮게 나오는 _x를 찾는 문제로 귀결한다. 하지만 위의 목적식에서 계산된 _x는 원본 데이터 x뿐 아니라 모델의 파라미터인 θ값도 영향을 미치기 때문에 고정된 θ값을 사용하면 노이즈가 잔뜩 낀 데이터로 압축되는 문제가 있다. 그래서 매 순간 θ값을 새로 랜덤으로 뽑아내는 방법으로 _x값을 일반화 할 수 있도록 했다. 결국 데이터 증류 문제는 최종적으로 아래 [그림 5]와 같이 정의한다.

[그림 5] Data Distillation의 최종 문제 정의: θ값이 증류에 영향을 미칠 수 없도록 매 순간 랜덤으로 생성한다.

데이터 증류 과정에서는 실제 전체 데이터의 손실 함수의 값과 압축 데이터의 손실 함수의 값이 동일하게 감소하도록 모델을 훈련하고, 손실 함수를 _x에 대해 미분한 값을 이용해 _x값을 업데이트 하는 방식으로 동작한다. 아래 [그림6]의 의사 코드를 보면 더 자세하게 이해할 수 있다.

[그림 6] 데이터 증류 의사코드

데이터 증류 문제는 전체 N개의 데이터 셋을 정해진 M개의 데이터 셋으로 압축하는데 있다. 매 학습 주기 t마다 모델 파라미터 θ값을 랜덤 확률 분포에서 새롭게 샘플링 하여 대입한다. 그리고 압축 데이터 _x 에 대해 경사 하강 알고리즘(Gradient Descent)를 수행한다. 이렇게 계산된 새로운 θ_1과 실제 데이터의 오차(L)를 계산하고 이 값을 _x에 대해 미분한 값을 이용해 _x값을 업데이트 한다. 아래 실험 결과는 이 알고리즘을 이용하여 압축한 데이터를 보여준다.

[그림 7] 데이터 증류 실험 결과

MNIST의 경우 데이터 증류 결과 각 클래스 별로 10개의 데이터를 증류하였고 이를 이용해 학습을 한 결과 93.76%의 정확도를 가진 모델을 만들 수 있었다. MNIST의 경우 어느 정도 형체를 알아볼 수 있는 이미지로 나왔는데, 반면 CIFAR의 경우에는 특징을 알아보기 힘든 형태의 노이즈한 이미지가 증류되어 나왔다. 이 경우에도 학습 결과 54.03%의 정확도를 가진 모델을 만들 수 있었다.

의도적으로 악의적인 데이터 생성하기

데이터 증류 방법을 모델에게 혼란을 주는 데이터를 생성하는 악의적인 목적으로 사용할 수 있다. 단순히 공격 대상이 되는 클래스 K를 전혀 상관없는 클래스 T로 설정하고 데이터를 증류하는 것만으로도 이 공격은 성공한다. 아래 [그림 8]은 이를 표현하는 목적식을 나타내고, [그림 9]는 이를 이용해서 모델을 공격하는 상황을 그림으로 보여준다.

[그림 8] 악의적인 데이터를 생성하는 데이터 증류 목적식
[그림 9] 데이터 증류를 이용한 데이터 중독 공격

결론

Hinton 교수의 모델의 지식 증류 이후로 증류라는 이름 붙은 다양한 논문이 제출되고 있다. 이번에 선택한 논문 또한, 이를 이용해서 흥미로운 새로운 연구 분야를 만들어냈다는 점에서 재밌게 읽을 수 있었다. 하지만 실험 결과가 다소 부족하고, 데이터 증류를 사용하면 좋다라고 주장하기에는 많이 빈약해 보였다. 앞으로 이쪽 분야에서도 추가적인 연구가 필요할 듯 하다.

레퍼런스

[1] Robust Physical-World Attacks on Deep Learning Visual Classification
[2] Dataset Distillation

--

--

scalalang2
CURG
Editor for

평범한 프로그래머입니다. 취미 논문 찾아보기, 코딩 컨테스트, 언리얼 엔진 등 / Twitter @scalalang2 / AtCoder @scalalang