StarGAN — 단일 모델로 다중 도메인 이미지 변환기 만들기

scalalang2
CURG
Published in
11 min readJun 26, 2020

이번 포스팅에서는 CVPR 2018 에서 발표된 논문인 StarGAN[1]을 소개하고자 한다. 최근 CVPR 2020에서 StarGAN-v2가 새롭게 발표 되었는데 이번 포스팅에서는 StarGAN-v1을 먼저 다루고, 다음 포스팅에서는 바로 이어서 StarGAN-v2를 정리하려고 한다.

구성

  • 이미지 변환
  • StarGAN이란 무엇인가?
  • 교차-도메인 모델 (Cross-domain Models)
  • StarGAN의 단일 모델 (Unified Model)
  • 다중 데이터 셋을 통합하여 학습하기
  • 결론

이미지 변환

이미지 변환이란 어떤 이미지를 특정한 특징을 가진 이미지로 변환하는 것을 말한다. 예를 들어 무표정인 사람의 모습을 웃는 모습으로 바꾼다거나 혹은 성별을 바꾸는 등 다양한 특징을 가지게 변환시키는 작업을 “이미지 변환", 변환할 때 적용하는 특징을 “도메인" 이라고 부른다. 딥러닝 분야에서 이미지 변환기의 대표적인 예로 Cycle-GAN 이 있다[2].

StarGAN이란 무엇인가?

이 논문에서는 하나의 단일 뉴럴 네트워크만 학습해서 사용해도 다중 도메인을 가진 이미지 변환을 할 수 있는 방법을 제시한다. 이 논문의 첫 페이지에 실린 그림을 보면 Abstract를 읽기도 전에 이 논문의 Contribution이 무엇인지 알 수 있다.

[그림 1.] CelebA와 RaFD 데이터 셋을 StarGAN으로 학습하여 이미지를 다른 도메인으로 변환한 결과

그림 1 에서는 4x9 행렬에 이미지가 채워져 있는데 이 중 Input이 원본 이미지이다. 이 원본 이미지를 각각 [Blond hair, Gender, Aged, Pale Skin] 그리고 [Angry, Happy, Fearful] 도메인으로 변환된 결과가 Input의 오른쪽에 나타난다. 도메인을 2개로 구분된 이유는 학습한 데이터 셋이 다르기 때문이다. 뒤에 다시 기술하겠지만 위의 결과는 단일 모델에서 2개의 데이터 셋을 동시에 학습하여 출력한 결과이다. StarGAN 논문에서는 해결하고자 한 문제와 기여는 아래와 같다.

  • 최근 연구에서 사용되는 도메인-변환 기법은 두 개 이상의 도메인을 학습할 때 Scalability와 Robustness가 부족하다.
  • StarGAN은 하나의 통합된 모델을 사용해서 다중-도메인과 다양한 데이터셋을 하나의 네트워크에서 학습할 수 있게 해준다.

교차 도메인 모델 (Cross-domain Models)

기존 연구[CycleGAN, DiscoGAN, pix2pix, cGAN 등]에서는 한 개의 특징만을학습해서 변환하는 방법을 제시한다. 이 방법으로 구현된 신경망은 웃는 얼굴을 우는 얼굴로 바꾸는 작업밖에 하지 못한다. 이런 환경에서 다양한 도메인으로 변환하려면 도메인 K당 K(K-1)개의 네트워크가 필요하다.

[그림 2] 크로스 도메인 모델과 StarGAN의 아키텍처

[그림 2]는 논문에서 나온 그림으로 교차 도메인 모델과 StarGAN 모델의 아키텍처를 보여준다. 그림의 왼쪽에 있는 그림이 기존 모델인 크로스 도메인 모델인데 이 모델에서 숫자가 적힌 원은 하나의 도메인을 나타내고 G_ij 는 도메인 i에서 도메인 j로 변환하는 하나의 신경망을 나타낸다. 이런 모델에서 각각의 생성망 G는 전체 데이터 셋을 학습하기는 하지만 특정 도메인에 대해서만 학습하기 때문에 연산력 측면에서 낭비가 발생한다. 그리고 특정 도메인만 특화해서 학습하기 때문에 일반적인 정보를 얻지 못해서 생성하는 이미지의 질적 차이가 존재한다.

StarGAN의 단일 모델 (Unified Model)

반면, [그림2]의 오른쪽 모델인 StarGAN은 하나의 신경망을 이용해서 많은 도메인으로 변환하기 때문에 일반적인 지식을 학습하여 더 높은 퀄리티의 이미지를 생성한다. 또한, 연산력을 비용으로 생각한다면 매우 경제적인 모델이다. 지금부터는 StarGAN의 아키텍처와 손실 함수에 대해 알아 볼 것이다.

[그림 3] StarGAN의 아키텍처

그림 3. 은 StarGAN의 아키텍처를 보여준다. 이 아키텍처를 보면 우리는 몇 가지 사실을 알 수 있다. 먼저 기존 GAN에서는 잠재 변수 z를 입력값으로 받는 반면, StarGAN에서는 변환하고자 하는 도메인 정보(c)와 원본 이미지(x)를 입력값으로 받는다. 원본 이미지를 입력값으로 받는 건 변분 오토 인코더 (VAE)가 사용된 UNIT [3]모델에서 아이디어를 차용했다. 그리고 판별기는 원본 이미지의 Real/Fake여부에 더해서 특정 도메인 정보까지 맞추는 걸 목표로 한다 다음으로는 StarGAN에서 정의한 손실함수를 다룰 텐데 그 전에 손실함수가 이루고자 하는 목적을 다시 한 번 상기해보자.

  • 하나의 손실 함수를 통해 많은 도메인을 학습할 수 있어야 한다.
  • 도메인을 학습하면서도 이미지의 퀄리티는 잃지 않아야 한다.
  • 원본 이미지와 타겟 도메인이 주어지면 원본 이미지를 타겟 도메인으로 변환 할 수 있어야 한다.

Loss Function

StarGAN의 손실 함수는 기존 적대적 생성망의 손실 함수를 그대로 사용하고 몇 가지 계산을 추가하였다. 아래 그림은 손실 함수 각 부분을 나눈 것 이다. 적대적 생성망에는 Discriminator 의 손실 함수인 L_D와 Generator 손실 함수인 L_G, 이 두개의 손실 함수를 정의해야 하는데 StarGAN에서는 아래 네 가지 함수를 먼저 정의하고 이를 조합하여 L_D와 L_G 함수를 정의한다.

[그림 4] StarGAN의 4가지 손실 함수
  • x : 원본 이미지를 나타낸다.
  • c : 바꾸고자 하는 도메인을 나타낸다. 예를 들어 무표정인 사람을 웃게하거나(happy), 슬프게 하는(sad) 등의 목적 도메인을 one-hot vector로 사용한다.
  • c’ : 변환 하기 전 원래 이미지의 도메인을 나타낸다.
  • Adversarial Loss: GAN의 MinMax-Game과 동일하게 정의한다. 한 가지 다른 점은 G의 입력 값으로 latent space를 이용하지 않고 원본 이미지인 x를 그대로 사용한다는 점이다. 논문에서는 일반화를 하기 위해 GAN의 MinMax-Game 수식으로 정의하긴 했지만 실제로 학습에 사용할 때는 Wasserstein-GAN[4]의 손실 함수를 이용한다.
  • Domain Classification Loss at Discriminator: StarGAN의 Discriminator는 주어진 이미지가 원본인지 가짜인지 구분하는데 더해 원본 이미지의 도메인 까지 예측해야 한다. 그 수식을 정의한게 그림 3.의 2번 Loss이다. 이렇게 손실 함수를 정의함으로써 신경망은 각 도메인의 특징까지 학습 할 수 있게 된다. (-log함수는 1에 가까우면 0으로 수렴하고 0에 가까우면 무한대로 발산하는 함수이다, Dcls(c’|x)가 잘 학습해서 1을 반환하면 loss는 줄어 들고 반대 경우에는 값은 무한대가 된다)
  • Domain Classification Loss at Generator: Generator 또한 도메인을 학습해야 한다. StarGAN 논문의 목적이 생성망에서 이미지를 잘 변환하는 것이기 때문에 이 부분이 중요한데, 2번은 원본 이미지가 도메인 c’로 잘 분류하게끔 학습 시킨다면, 3번은 생성기가 생성된 이미지가 타겟 도메인 c로 잘 분류되게끔 학습 시키는데 목적을 둔다.
  • Reconstruction Loss: 이 손실 함수는 변환되는 과정에서 이미지의 퀄리티를 지키기 위해 사용된다. CycleGAN의 Cycle Consistency Loss의 원리를 이용하였다. 하나씩 보면, G(x,c)=y에서 이미지를 타겟 도메인으로 변환하고, 이렇게 생성된 이미지를 G(y,c’)=x’를 통해 다시 원래 도메인으로 변환시킨다. 마지막으로는 원본 이미지 x와 복구된 이미지인 x’의 맨하탄 거리를 계산한다.

지금까지 이렇게 총 4가지 Loss를 정의했고 최종적으로 L_D와 L_G는 아래와 같이 조합해서 사용한다. 아래 그림에서 λ값은 하이퍼 파라미터 이다. 논문에서는 λ_cls = 1, λ_rec = 10으로 해서 Reconstruction Loss에 더 높은 패널티를 부여해서 사용한다. 마지막으로 Discriminator와 Generator는 논문의 부록에 실려있으며 무난하게 CNN으로 구성한다.

[그림 5] StarGAN의 판별기/생성기의 최종 손실 함수

다중 데이터 셋을 통합하여 학습하기

StarGAN 논문에서는 CelebA[5]와 RaFD[6] 데이터 셋을 하나의 모델에서 학습한 결과를 보여준다. CelebA는 사람 이미지를 헤어 스타일이나, 얼굴형 등 40여개의 특징으로 분류하였고, RaFD는 8개의 표정으로 사람 사진을 분류한 데이터 셋이다. 이 두개의 데이터 셋은 각각 다루는 도메인이 다루기 때문에 이를 통합해서 학습하려면 몇 가지 전략이 필요하다.

  • Mask Vector : 두 개의 데이터 셋의 Class를 통합하여 학습하기 위해 두 가지 클래스를 모두 만족하는 one-hot vector를 만들어서 사용하는데 이를 Mask Vector 라고 부른다.
  • CelebA는 데이터 셋이 200,000개 이상인 반면 RaFD는 500여개 밖에 안되기 때문에 그냥 배치로 학습하면 CelebA 데이터에 크게 편향될 수 있다. 그래서 여기에서는 CelebA는 10 epoch로 학습할 때 RaFD는 100 epoch로 학습하는 방법으로 두 데이터 셋을 균등하게 학습한다.

지금까지 StarGAN을 이해하기 위한 주요 내용을 서술하였다. 이제는 위 내용으로 구성된 네트워크의 실험 결과를 보여주려 한다. GAN은 다른 연구에 비해 실험을 수치로 보여주기 어렵기 때문에 논문에서는 DIAT, CycleGAN, IcGAN, StarGAN 이 네 가지 모델로 동일한 도메인으로 변환하여 이미지를 나열하였다.

[그림 6] StarGAN 실험결과

그림 6. 실험 결과를 보면 맨 왼쪽의 이미지가 네트워크의 입력값이고 그 오른쪽으로 나열된 이미지가 출력 결과이다. 우측에서 4개의 컬럼 [H+G, H+A, G+A, H+G+A]는 2개 이상의 도메인으로 동시에 이미지를 변환한 결과이다. DIAT 모델인 경우에는 3개의 도메인으로 변환할 때 이미지가 뭉개지는 것이보인다. 논문에서는 양적인 평가를 위해서 설문조사를 통해 어느 모델의 결과가 좋은지 조사하였는데 설문에 참여한 사람들은 높은 확률로 StarGAN을 선택하였다.

이제 여기서 우리는 StarGAN 논문의 novelty를 다시 한 번 떠올릴필요가 있다. 다른 모델은 k(k-1)개의 신경망을 사용하지만, StarGAN은 단일 신경망(Unified Model)으로 기존 모델보다 더 나은 결과를 보여준다.

결론

본 포스팅에서는 StarGAN이 해결하고자 한 문제와 접근 방법을 소개하였고 손실 함수를 정리하였다. 논문의 저자는 자신의 깃허브에 StarGAN을 PyTorch로 구현한 코드를 공개했다[7]. 이 포스팅에서 소개한 내용만 가지고도 코드를 이해하는데 큰 어려움은 없기 때문에 더 자세한 내용이 궁금한 사람들은 저자의 깃허브를 참고하면 좋을 것 같다. 서론에 기술했듯이 올 해 CVPR 2020에 StarGAN-v2가 게재되었는데 이어지는 다음 AI 세션에서는 StarGAN-v2를 정리해보려 한다.

레퍼런스

[1] StarGAN : Unified Generative Adversarial Networks for Multi-Domain Image-to-Image Translation
[2] CycleGAN : Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
[3] UNIT : Unsupervised image-to-image translation networks
[4] Wasserstein-GAN (유튜브: 룩팍-베셔슈타인 GAN 쉽게 이해하기)
[5] CelebA, CelebFaces Attributes Dataset
[6] RaFD, Radbound Faces Database
[7] Pytorch Implementation of StarGAN

--

--

scalalang2
CURG
Writer for

평범한 프로그래머입니다. 취미 논문 찾아보기, 코딩 컨테스트, 유니티 등 / Twitter @scalalang2 / AtCoder @scalalang