Group Study (2021-2022)/Machine Learning (GAN)

[Machine Learning] 6주차 스터디-조건부 GAN

ru_urzlo 2021. 11. 22. 01:38

✨ 목표

이미지를 단일한 클래스로 고정한 채로 다양한 이미지를 생성할 수 있게 하는 것
ex) 🙋‍♀️ 개발자: 숫자 3을 표현하는 다양한 이미지를 생성해줘!
   💻 GAN: OK!


❓ 조건부 GAN 구조

1. 생성기에 임의의 시드와 함께 어떤 이미지를 원하는지 입력을 넣어주어야 함.
2. 판별기 : 생성된 이미지와 실제 이미지 구별 ▶ 클래스 레이블과 이미지 사이의 관계 학습
     ▶ 판별기에도 클래스 레이블에 대한 정보를 같이 제공해야 함!

주요한 차이점 : 판별기와 생성기 모두 이미지 데이터 외에도 클래스 레이블을 추가로 입력받음


📌 판별기

이미지 픽셀 데이터와 클래스 레이블 정보를 동시에 받도록 판별기를 업데이트 해야함.
방법 : forward() 함수에서 이미지 텐서와 레이블 텐서를 동시에 받게 하고 결합한다.

🔻 판별기 함수

class Discriminator(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        # 두 텐서를 이은 길이 = 784(이미지 텐서의 길이) + 10(레이블 텐서의 길이)
        self.model = nn.Sequential(
            nn.Linear(784+10, 200),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(200),

            nn.Linear(200, 1),
            nn.Sigmoid()
        )
        
        self.loss_function = nn.BCELoss()

        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        self.counter = 0;
        self.progress = []

        pass
    
    # ----- forward() 함수
    # 레이블 텐서: 원핫 인코딩되어있는 텐서
    # torch.cat(): 하나의 텐서를 다른 텐서에 잇는 역할
    def forward(self, image_tensor, label_tensor):
        # 시드와 레이블 결합
        inputs = torch.cat((image_tensor, label_tensor))
        return self.model(inputs)
    
    
    def train(self, inputs, label_tensor, targets):
        # 신경망 출력 계산
        # forward()를 호출할 때 레이블 추가
        outputs = self.forward(inputs, label_tensor)
        
        loss = self.loss_function(outputs, targets)

        self.counter += 1;
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
            pass
        if (self.counter % 10000 == 0):
            print("counter = ", self.counter)
            pass

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        pass
    
    
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
        pass
    
    pass

🔻훈련 반복문

# 레이블 텐서를 추가로 train() 함수에 전달
for label, image_data_tensor, label_tensor in mnist_dataset:
    # 실제 데이터
    D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))
    # 생성 데이터
    D.train(generate_random_image(784), generate_random_one_hot(10), torch.FloatTensor([0.0]))
    pass

🔻generate_random_one_hot(): 임의의 원핫 인코딩된 클래스 레이블 벡터

# 크기는 정수로 지정해야 함
def generate_random_one_hot(size):
    label_tensor = torch.zeros((size))
    random_idx = random.randint(0,size-1)
    label_tensor[random_idx] = 1.0
    return label_tensor

✅ 손실 확인

손실 차트는 크게 변한 게 없고, 기존 판별기와 거의 똑같음.


📌 생성기

시드와 레이블 텐서를 생성기에 투입하게 햇으므로, 두 텐서를 결합해서 신경망에 전달하게 수정해야함.

🔻 생성기 함수

class Generator(nn.Module):
    
    def __init__(self):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(100+10, 200),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(200),

            nn.Linear(200, 784),
            nn.Sigmoid()
        )
        
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        self.counter = 0;
        self.progress = []
        
        pass
    
    
    def forward(self, seed_tensor, label_tensor):        
        # 시드와 레이블 결합
        inputs = torch.cat((seed_tensor, label_tensor))
        return self.model(inputs)


	# 레이블 텐서를 받도록 수정
    def train(self, D, inputs, label_tensor, targets):
        # 신경망 출력 계산
        g_output = self.forward(inputs, label_tensor)
        
        # 판별기로 전달
        # 생성기에서 생성된 이미지들을 판별기의 forward()함수에 넘김
    	# --> 생성기가 다른 레이블로 잘못 판단하는 것을 방지
        d_output = D.forward(g_output, label_tensor)
        
        loss = D.loss_function(d_output, targets)

        self.counter += 1;
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
            pass

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

        pass
    
    def plot_images(self, label):
        label_tensor = torch.zeros((10))
        label_tensor[label] = 1.0
        # 2행 3열로 샘플 이미지 출력
        f, axarr = plt.subplots(2,3, figsize=(16,8))
        for i in range(2):
            for j in range(3):
                axarr[i,j].imshow(G.forward(generate_random_seed(100), label_tensor).detach().cpu().numpy().reshape(28,28), interpolation='none', cmap='Blues')
                pass
            pass
        pass
    
    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0), figsize=(16,8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))
        pass
    
    pass

📌 훈련 반복문

레이블 텐서를 판별기와 생성기에 전달해야한다.

🔻 훈련 반복문 에포크 반복문 내부 코드

for epoch in range(epochs):
  print ("epoch = ", epoch + 1)

  for label, image_data_tensor, label_tensor in mnist_dataset:
    # 참에 대해 판별기 훈련
    D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))

    # 임의의 원핫 인코딩된 값을 레이블로 이용
    random_label = generate_random_one_hot(10)

    # 거짓에 대해 판별기 훈련
    # G의 기울기가 계산되지 않도록 detach() 함수를 이용
    D.train(G.forward(generate_random_seed(100), random_label).detach(), random_label, torch.FloatTensor([0.0]))
    
    # 임의의 원핫 인코딩된 값을 레이블로 이용
    # 판별기와 생성기 모두에 임의의 같은 레이블 텐서 투입
    random_label = generate_random_one_hot(10)

    # 생성기 훈련
    G.train(D, generate_random_seed(100), random_label, torch.FloatTensor([1.0]))

    pass
    
  pass

📌 차트 그리기

label을 정수로 받아서, 이로부터 원핫 인코딩된 텐서를 만들고 생성기에 전달.
여섯 개의 다른 임의의 시드로 여섯 개의 이미지가 생성되어
최종적으로 격자에 그려짐.

🔻 생성기 클래스에 plot_images() 메서드 추가

    def plot_images(self, label):
        label_tensor = torch.zeros((10))
        label_tensor[label] = 1.0
        # 2행 3열로 샘플 이미지 출력
        f, axarr = plt.subplots(2,3, figsize=(16,8))
        for i in range(2):
            for j in range(3):
                axarr[i,j].imshow(G.forward(generate_random_seed(100), label_tensor).detach().cpu().numpy().reshape(28,28), interpolation='none', cmap='Blues')
                pass
            pass
        pass

💡 조건부 GAN 결과 확인하기

1️⃣ 판별기 손실값

 

2️⃣ 생성기 손실값

 

3️⃣ 숫자 5를 의미하는 이미지를 그리게하기


🌟 핵심 정리

  • 조건부 GAN은 원하는 클래스의 데이터 생성이 가능하다.
  • 훈련 시 판별기에 이미지를 보강해서 전달하며, 생성기에는 클래스 레이블을 통해 시드가 투입되어야 한다.
  • 조건부 GAN은 레이블 정보를 받지 않는 일반적인 GAN보다 좋은 이미지 품질의 데이터를 생산함.