Group Study (2021-2022)/Machine Learning (GAN)

[Machine Learning] 4주차 스터디-튼튼한 GAN 만들기

ddo0 2021. 10. 31. 02:09

< 네트워크 모델을 이용해 얻고자 하는 목표/이상적인 결과 >

1. 각각 서로 다른 이미지들을 만들어내기

2. 그저 평균처럼 애매모호한 이미지가 아니라 훈련 샘플처럼 보이는 이미지들 만들기

1) 백쿼리(Backquery)

 하나의 숫자를 표현하는 원핫 인코딩 벡터를 훈련되어 있는 네트워크에 넣어 그 숫자에 맞는 이상적인 이미지를 거꾸로 만드는 것

- 같은 원핫 인코딩 벡터는 같은 결과를 출력한다.

- 레이블을 나타내는 모든 훈련 데이터의 뭔가 평균적인 이미지가 나온다.

👉🏻 백쿼리만으로는 원하고자 하는 것을 달성하지 못함. 단, 백쿼리 이용은 한다. 백쿼리 외에도 다른 것들이 필요하다.

2) GAN(Generative adversarial network)을 활용하자!

생성적 적대 신경망이라고 불리는 GAN은 판별기와 생성기, 즉 두 개의 신경망 모델이 서로 적대적 관계로 경쟁을 하며 서로를 뛰어넘기 위한 훈련을 반복해, 결과적으로 둘 다 성능이 좋아지게 되는 구조를 가리킨다.

판별기

📍 훈련 데이터셋을 진짜라고 예측하도록 훈련하는 모델

생성기

📍 판별기가 분류할 이미지를 생성하는 역할을 하며, 생성한 가짜 이미지에 대한 판별기의 결괏값이 1.0이 되는 것을 목표로 하는 모델이다.

1. 사람이 자세한 사항을 설정하지 않아도 된다.

    Ex) 진짜 이미지 판별하기 위해 어떤 손실함수 사용할 건지, 어떤 방식을 따를 건지

2. 경쟁을 통해 발전한다.

    판별기는 진짜를 가려내기 위해, 생성기는 판별기가 속을 수 있을 만한 가짜를 만들기 위해 성능이 개선될 것이고, 서로가 경쟁하며 발전하는 경우이다. 이는 굉장히 영리한 접근법이라고 볼 수 있다.

GAN 훈련 방법

*주의사항*

- 판별기와 생성기 모두를 "동시에" 훈련시켜야 한다.

- 판별기와 생성기 모두 "비슷한 수준"으로 훈련이 이루어지도록 해야 한다.

GAN 훈련 단계

1단계 - 판별기에 실제 데이터를 보여주고 출력값이 1.0이라는 값이어야 한다고 알려준다.

👉🏻 이 과정에서 발생하는 오차는 판별기를 업데이트 하는 데 사용한다.

2단계 - 생성기로부터 만들어진 가짜 데이터를 판별기에 보여주고 출력값이 0.0이어야 한다고 알려준다.

👉🏻 이 과정에서 발생하는 오차는 판별기를 업데이트 하는 데 사용한다.(생성기를 업데이트하지 않도록 조심해야 한다.)

3단계 - 판별기에 생성기의 결과를 보여주고, 생성기에 결과가 1.0이어야 한다고 알려준다.

👉🏻 이 과정에서 발생하는 오차는 생성기만을 업데이터하는 데 사용된다.

📌 3단계에서 판별기를 업데이트하지 않는 이유 = 판별기가 분류를 잘못하게끔 만들고 싶지 않기 때문이다.

GAN 훈련 시 주의사항(PATTERN 1010 통해서 이해하기)

📎 단계에 맞춰 훈련하기

1. 훈련 데이터셋을 미리 살펴보기

2. 판별기가 적어도 실제 데이터, 그리고 임의의 노이즈를 구별할 수 있는 성능을 가지는지 확인하기

    👉🏻 random을 import 하고 임의의 노이즈 패턴을 반환하는 함수를 만든 후, 훈련을 진행하고 진짜 데이터와 임의의 노이즈를 판별기에 투입해 결과를 확인한다.

- 패턴 1010의 경우에서 임의의 노이즈를 반환하는 함수는 다음과 같다.

def generate_random(size):
	random_data = torch.rand(size)
    return random_data

3. 훈련되지 않은 생성기가 올바른 형태의 진짜인 것 같은 가짜 데이터를 만드는지 확인하기

    👉🏻 생성기 class를 만들고 객체를 생성한 다음, 0.5라는 단일값을 가지는 텐서를 인자로 준 결괏값을 확인한다.

4. 손실의 변화 시각화하기

image_list = []

# 반복문으로 훈련 시 1000회마다 생성기의 결과를 image_list에 저장한다.
if (i%1000 = 0):
	image_list.append(G.forward(torch.FloatTensor([0.5])).detach().numpy())

# 훈련 동안 생성기가 생성한 이미지들을 바탕으로 생성기의 변화를 보여준다.
plt.figure(figsize = (16,8))
plt.imshow(numpy.array(image_list).T, interpolation='none', cmap='Blues')

- 다른 방법은 아래 3번 클립 참고하기

📎 잘 훈련된 GAN의 출력은 약 0.5로, 평균제곱오차(MSE)의 이상적인 값은 0.25임을 기억하기

       👉🏻 평균제곱오차 MSELoss() 손실함수는 언제나 0과 1 사이의 값을 가진다.

📎 생성기와 판별기를 각각 시각화하기

    생성기의 손실은 생성된 데이터로부터 발생한 판별기의 손실이다. 

    👉🏻  pandas와 matplotlib.pyplot as plt를 import 하고, 판별기와 생성기 클래스에 시각화할 수 있는 함수를 생성해 이를 호출한다.

import matplotlib.pyplot as plt

class Discriminator():
	def plot_progress(self):
    	df = pandas.DataFrame(self.progressm columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1m marker='.', grid=True, yticks=(0, 0.25, 0.5))
        pass

class Generator():
	def plot_progress(self):
    	df = pandas.DataFrame(self.progressm columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1m marker='.', grid=True, yticks=(0, 0.25, 0.5))
        pass
        
D = Discriminator()
G = Generator()

D.plot_progress()
G.plot_progress()

    👉🏻 생성기가 어떻게 변화했는지 보여주는 다른 방법도 존재한다. (클립 1의 4번 참고하기)

모드 붕괴 발생

모드 붕괴란?

- 다른 선택지가 있음에도 불구하고 생성기가 하나의 클래스로만 숫자를 생성하는 현상을 이야기한다. 발생 원인이나 해결 방법에 대해 아직까지 연구가 활발히 일어나고 있는 분야이기도 하다. 명확한 해결을 할 수는 없겠지만, 개선할 수 있는 방법 2가지는 다음과 같다.

개선 방법

1. GAN 훈련 성능을 향상시킨다.

2. 시드를 활용한다.

🔺 GAN 훈련 성능 향상시키기

*기억해야 할 것* 훈련에서 중요한 것은 양보다 질이다.

✔️ 평균제곱오차 MSELoss() 손실함수 대신 이진교차 엔트로피 BCELoss() 손실함수를 이용한다.

    - BCELoss()의 경우, 손실의 이상적인 값은 ln (2), 즉, 0.69이다.

✔️ Sigmoid() 활성화 함수 대신 LeakyReLU() 활성화 함수를 사용한다.

    - 단, BCE 손실함수 사용하는 경우는 무조건 맨 마지막 레이어에 Sigmoid() 함수를 사용한다.

      👉🏻 BCE 손실함수는 0~1 사이 외의 값을 받을 수 없는데, LeakyReLU 함수는 0~1 외의 값을 내보낼 수 있기 때문이다.

    - LeakyReLU 함수를 사용할 때 기울기는 사용자가 임의로 설정한다.

✔️ SGD 옵티마이저 대신 Adam 옵티마이저를 사용한다.

✔️ LayerNorm()을 통해 정규화를 진행해 평균을 0으로 맞추고 분산을 제한한다.

    - 극단적인 값을 피할 수 있는 방법이다.

🔺 시드로 실험하기

    1. 이미지를 만들어낼 수 있는 두 개의 시드 seed1, seed2를 생성한다.

    2-1. 두 개의 시드 사이에 일정한 간격으로 새로운 중간 시드를 생성한다.

    2-2. 두 개의 시드 합을 통해 새로운 시드 seed3을 생성한다.

            👉🏻 5와 3을 보여주는 이미지를 생성한 두 개의 시드의 덧셈 계산을 통해 새로운 시드를 만들면, 5와 3이 합쳐진 것 같은 모습이 생성된다.

    2-3. 두 개의 시드 차를 통해 새로운 시드 seed4를 생성한다.

            👉🏻 덧셈과 달리, 뺄셈 계산을 통해 새로운 시드를 만들면 딱 보기에 연관성이 없는, 새로운 이미지가 생성된다. 이는 시드끼리의 단순 계산을 통해 생성된 이미지는 생각보다 복잡한 논리로 돌아가고 있다는 사실을 보여준다.

GAN과 컬러 이미지, 그리고 데이터셋

📍 단색 이미지가 아닌 컬러 이미지를 활용할 때는 3차원의 행렬을 사용한다. 이때, 세 번째 차원의 값은 항상 3이 된다.(RGB)

📍 데이터셋의 파일을 활용할 때, 각각의 파일을 모두 열고 닫으면 상당히 '비효율적'이다. 따라서, 랜덤 액세스를 할 때는 데이터를 리패키징하여 발전된 형식을 활용하는 것이 좋은데, 이때 계층적 데이터 형식인 HDF를 이용하면 좋다.

    👉🏻 계층적이라는 말은 하나 이상의 그룹을 가질 수 있고, 그룹 안에 여러 개의 데이터셋이 포함될 수 있으며 그룹 안에 그룹이 존재할 수 있다는 사실을 알려준다. 이러한 특성은 우리가 평소에 사용하는 폴더 구조와 유사함을 보여준다.

📍 압축 파일(.zip)을 활용할 때는 zipfile을 import 해서 사용한다.

📍 GAN은 훈련 데이터를 기억하거나 단순히 복사 붙여넣기를 하지 않고, 훈련 데이터의 확률 분포를 파악하고 이를 재현한 데이터를 생성하기 위해 노력한다.

    👉🏻 데이터를 저장할 때는 drive 안 같은 계층에 존재하도록 한다.