[GAN] Generative Adversarial Nets 논문 공부

2023. 2. 19. 21:13인공지능(AI)

생성 모델 중 하나인 GAN에 대해 알아보자.

 

 

Two-Player Minimax Game

Counterfeit vs Police

저자는 Generator를 Counterfeit $($위조지폐범$)$에, Discriminator를 Police $($경찰$)$에 비유한다.

 

위조지폐범이 위조지폐를 만들면 경찰은 그 지폐가 진짜인지 가짜인지 판별하게 된다. 처음에는 기술이 부족한 위조지폐범이 엉터리 지폐를 만들어 쉽게 가짜라는 것을 들킬 것이다. 하지만 도둑도 학습을 잘 하고 있다면 점차 진짜같은 돈을 만들어 경찰도 진짜 지폐인지 가짜 지폐인지 구분하기 힘든 수준까지 오게 될 것이다.

 

이처럼 Discriminator에게 real image를 보여주면서 진짜와 가짜를 판단할 수 있도록 학습시키고, Generator는 fake image를 real image처럼$($real image분포를 따르도록$)$ 생성하여 Discriminator를 속이도록 학습한다. 말 그대로 Generator와 Discriminator 간 적대적 학습이라고 할 수 있다.

 

value function V

$$_{\ G}^{min}\ _{\ D}^{max}\ V(D, G)\ =\ \mathbb{E}_{x\sim p_{data}(x)}[log D(x)]\ +\ \mathbb{E}_{z\sim p_{z}(z)}[log (1-D(G(z)))]$$

 

수식에서 ${x\sim p_{data}(x)}$는 실제 data분포에서 image x를 뽑는다는 의미이고, $\mathbb{E}$는 기댓값을 의미한다. Discriminator D는 value function V를 최대화 하는 방향으로 학습하게 되고, 반대로 Generator G는 최소화 하는 방향으로 학습하게 된다. Generator G의 경우 $\mathbb{E}_{x\sim p_{data}(x)}[log D(x)]$에는 관여할 수 없기 때문에 $\mathbb{E}_{z\sim p_{z}(z)}[log (1-D(G(z)))]$부분만 사용한다.

 

위 식에서 보면 Generator는 $log (1-D(G(z)))$를 최소화 하는 방향으로 해야하지만, 실제 학습에서는 $log (D(G(z)))$를 최대화 하는 방식으로 학습한다.

 

 

$log (1-D(G(z))$를 그래프로 나타내면 위와 같이 나타낼 수 있다. 이런 식의 loss function을 갖게될 때, 학습 초기에 Generator는 Discriminator가 가짜라고 판단하기 쉬운 이미지를 생성하는데, 0에 가까운 부분에서 함수의 기울기가 완만하기 때문에 학습이 더딜 수 있다.

 

 

그래서 $log (D(G(z))$를 사용하여 위와 같은 그래프로 나타낼 수 있는데, 학습 초기에 큰 기울기를 통해 빠르게 학습할 수 있다는 장점이 있다.

 

 

 

검은 점선은 실제 데이터 분포를 나타낸 것이고, 녹색은 Generator가 생성한 이미지 분포이다. 하단에 검은 선은 Generator는 입력으로 랜덤으로 생성한 noise vector z를 주게 되는데 Generator가 z를 입력으로 받아 생성한 x와 매칭시킨다는 의미같다. 

 

$(a)$부터 시작해서 학습을 하다보면 Generator가 만들어 내는 데이터 분포가 실제 데이터 분포를 따라가게 된다. 결국 분포가 유사해지게 되면 Discriminator가 특정 이미지가 들어왔을 때 진짜인지 가짜인지 $\frac{1}{2}$의 확률로 판단하게 된다. 즉, 진짜 이미지나 가짜이미지가 들어와도 이를 구분할 수 없게 된다.

 

Generator가 실제 데이터 분포를 완전히 따라갔을 경우 $G(z) = x$를 만들어 낼 것이고, $D_{loss} = log D(x) +log (1-D(x))$로 정리하면 위의 그래프처럼 나와 $D(x) = \frac{1}{2}$에서 최소가 된다. 

 

 

 

MNIST

직접 학습시키면서 뽑아본 결과다. MNIST dataset에 대해서는 학습을 여러번 거칠 수록 알아 볼 만한 이미지가 생성된다.

 

Mode Collapse 문제

CelebA

CelebA dataset으로 훈련시켜 봤는데 MNIST에 비해 특징이 더 복잡하고 RGB color에 size도 더 크다보니 학습이 느리기도 하고 결과도 안 좋다. 또한 비슷한 결과만 출력하는 듯 보이는데, 특정 이미지 분포에 대해서만 편향되어 학습되는 mode collapse 문제로 볼 수 있을 것 같다.

 

reference

I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio. Generative Adversarial Networks. In NIPS, 2014.

'인공지능(AI)' 카테고리의 다른 글

Deep Fusion (작성중)  (0) 2023.09.11
Yolo 관련 공부  (0) 2023.08.19
Auxiliary Learning  (0) 2023.08.12
[CGAN] Conditional Generative adversarial Nets 논문 공부  (0) 2023.02.24