Multi-Task Learning as Multi-Objective Optimization

Inwan Yoo
Lunit Team Blog
Published in
6 min readJun 5, 2020

이 글은 2019년 1월, 루닛 블로그에 올렸던 포스트입니다. (https://blog.lunit.io/2019/01/30/multi-task-learning-as-multi-objective-optimization/)

Introduction

사람은 방의 사진을 볼 때 방의 구조가 어떻게 되고, 어떤 물건들이 있고, 그것들이 현재 카메라의 위치에서 얼마나 떨어져 있는지를 동시에 파악합니다. 우리는 사진을 단순히 2차원의 그림이나 패턴으로 생각하지 않습니다. 왜냐면 단순히 관측된 사진의 정보보다 더 많은 현실에서의 정보를 실제로 그 공간에 살면서 얻기 때문입니다.

오늘 소개할 “Multi-Task Learning as Multi-Objective Optimization” (NeurIPS 2018) 논문의 저자들은 Stein’s paradox를 예시로 들어 multi-task learning (MTL)의 중요성을 설명합니다. Stein’s paradox란 셋 또는 그 이상의 Gaussian random variable의 평균을 구할 때 모든 Gaussian 전체로부터 sampling하는 것보다, 각각의 Gaussian에서 독립적으로 sampling하는 것이 전체 분포의 평균을 더 잘 측정할 수 있다는 현상을 의미합니다. 방의 사진은 방이 내포하는 다양한 정보의 분포들로부터 얻은 하나의 sample입니다. 이를 뭉뚱그려진 ‘방 사진’의 분포로 이해하는 것보다, ‘방의 구조’ / ‘방의 물건들’ / ‘내 위치에서 방에 있는 물체들의 거리’ 등의 다양한 분포를 통해 추론하는 것이 더 좋은 추정을 할 수 있습니다.

Pareto optimality & Karush-Kuhn-Tucker (KKT) condition

딥러닝에서 multi-task learning은 보통 개별 task t에 대한 shared parameter θ^sh 및 task-specific parameter θ^t로 이루어진 각 task loss를 가중치 c^t

로 가중합(weighted sum)된 전체 loss를 줄이는 방식으로 이루어집니다. 이러한 정의는 직관적이지만, c^t를 구하기 위해 계산 비용이 큰 grid search 방식을 이용하거나, 이론적인 근거 없이 직관에 의존한 (heuristic) 방법이 적용되었습니다. 또, 이런 전체 loss의 전역 최적해를 정의하는 것은 사실상 불가능합니다. 어떤 해 θ는 특정 task1에서는 잘 동작하나, 다른 task2에서는 잘 동작하지 않을 수 있습니다. 또 다른 해 θ’는 오히려 task2에서는 잘 동작하지만 반면에 task1에서는 잘 동작하지 않을 수 있습니다. 이를 비교하기 위해서는 각 task 쌍들의 중요도를 비교할 수 있어야 하는데, 이는 일반적을 불가능합니다. 위의 방의 예시로 보자면, ‘방의 물건들’을 아는 것과 ‘방의 구조’를 아는 것이 서로 ‘얼마나’ 중요한 정보일지는 알기 어렵습니다. 때문에 저자들은 가중합 loss가 아닌 ‘vector-valued’ loss를 정의하여 사용합니다.

위부터 각각 task t에 대한 loss 함수, 각 task loss의 가중합에 대한 최적화 문제, vector-valued loss의 multi-objective 최적화 문제.

이렇게 loss를 vector로 만들면 무엇이 더 좋은 해인지 알 수 없습니다. 저자들은 MTL을 위한 Pareto optimality를 다음과 같이 정의합니다.

  1. 어떤 해가 다른 해보다 모든 task에서 성능이 더 좋다면 이 해는 다른 해를 압도(dominate)한다고 정의한다.
  2. 어떤 해를 압도하는 다른 해가 없다면 이 해는 Pareto optimal 한다.

Deep learning 방법은 gradient 기반 최적화에 의존하므로, 이를 위한 KKT 조건을 다음과 같이 정의합니다.

즉, KKT 조건이란 각 task에 대하여 task-specific parameter들은 수렴하였고 어떤 task 가중치에 대해 shared parameter도 수렴한 상황을 의미합니다. 저자들은 이런 상태의 해를 Pareto stationary point라고 정의합니다. 여기서 각 task의 KKT 조건 1의 가중합이 0이 되면, 최소한 KKT 조건을 만족하거나, 못해도 각 task-specific parameter를 학습시켜 모든 task를 더 최적화할 수 있습니다.

Multiple Gradient Descent Algorithm (MGDA)

MTL을 위한 MGDA는 다음과 같은 방식으로 진행됩니다.

위의 과정을 반복하면 모델은 KKT 조건을 만족할 수 있도록 학습됩니다.

MGDA에서 사용되는 Frank-Wolfe solver는 빠르기는 하나, feed-forward 자체의 횟수가 많아 다소 비효율적입니다. 저자들은 MGDA의 가중치 α를 구하기 위한 gradient 계산을 shared parameter가 아닌 shared parameter로 구성된 network인 encoder g의 결과 z=g(x)로 두고 같은 방식으로 학습하는 MGDA-UB를 구현합니다. 또, 이를 앞에서 정의한 방법으로 대체할 수 있음을 증명하고, 이것으로 실험을 진행합니다.

Conclusion

저자들은 크게 CelebA, MultiMNIST, 그리고 Cityscapes 세 개의 dataset에 대하여 실험하고, 모두 비교군인 grid search, uniform scaling, GradNorm 등의 다른 방법들보다 좋은 결과를 얻습니다. 특이한 것은 MGDA보다 low-dimensional approximation인 MGDA-UB가 더 좋은 성능을 얻는 점입니다. 저자들은 MGDA-UB가 더 MGDA보다 성능을 낮추지 않으며, 더 작은 차원에서 최적화를 하기 때문에 더 빠르고 안정적으로 학습이 가능했다고 설명합니다. 이전까지 deep learning 분야에서 MTL에 대한 이론적인 접근이 많이 없었는데, 이 논문에서는 Pareto optimality라는 개념을 통해 좀 더 명확한 이론적 근거를 대고 이를 실험적으로 증명합니다.

--

--