Variational AutoEncoder(VAE, Multi-VAE)

2024. 2. 21. 14:40ML model/Generative Model

이전에 학습한 AE, DAE, VAE의 차이점 도식화

 

VAE의 특징은 z(latent vector)생성 시, linear layer를 통과하여 나온 벡터 내의 값(스칼라) 뿐만 아니라 해당 값의 평균과 분산 정보까지 포함시키는 것이다.

    - 이때 input으로 제공된 모든 값은 정규분포를 따른다고 가정한다.

 

 

VAE

VAE 구조

 

구현방법에는 여러가지가 있을 수 있지만 간단하게 두 가지를 소개하려고 한다.

1) encoder에 의해 생성되는 z의 크기를 두배로 하여 mu, logvar로 나누는 방법

2) z를 생성하는 linear layer를 두개 설정하여 mu,logvar를 각각 생성하는 방법.

 

필자는 무비렌즈 데이터로 영화 추천이라는 테스크를 수행했는데 1번 방법의 성능이 훨씬 좋았다.

import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np

class VAE(nn.Module):
    """
    input : encoder_dims

        encoder = [n_items, dim1, dim2]
        decoder_dims = [dim2, dim1, n_items]
    """
    
    def __init__(self, config, encoder_dims, decoder_dims=None, dropout_rate=0.5): 
        super(VAE, self).__init__()
        self.config = config
        
        self.encoder_dims = encoder_dims
        if decoder_dims:
            assert decoder_dims[0] == encoder_dims[-1], "In and Out dimensions must equal to each other"
            assert decoder_dims[-1] == encoder_dims[0], "Latent dimension for p- and q- network mismatches."
            self.decoder_dims = decoder_dims
        else:
            self.decoder_dims = encoder_dims[::-1]
        
        self.encoder = self.build_layers(self.encoder_dims, config['activate_function'], module='encoder')
        self.decoder = self.build_layers(self.decoder_dims, config['activate_function'], module='decoder')
        
        self.drop = nn.Dropout(dropout_rate)
        
    def build_layers(self, dims, activate_function = 'ReLU', module=None):
        """
        Helper function to build layers based on the provided dimensions.

        Parameters:
            - dims (list): List of dimensions for the layers.

        Returns:
            - nn.Sequential: Sequential container for the layers.
        """
        layers = []
        if module == 'encoder':
            dims[-1] = dims[-1]*2 #last dims for mu and logvar.
        for i in range(1, len(dims)):
            layers.append(nn.Linear(dims[i-1], dims[i]))
            if i == len(dims)-1:
                pass
            else:
                if activate_function == 'ReLU':
                    layers.append(nn.ReLU())
                elif activate_function == 'Sigmoid':
                    layers.append(nn.Sigmoid())
                elif activate_function == 'Tanh':
                    layers.append(nn.Tanh())
                else:
                    pass

        return nn.Sequential(*layers)
        
    def forward(self, input):
        if self.config['denoising'] == 'Dropout':
            h = F.normalize(input)
            h = self.drop(h)
            
        elif self.config['denoising'] == 'Gaussian':
            h = self.add_noise(input)
            h = F.normalize(h)

        h = self.encoder(h)
        h, mu, logvar = self.reparameterize(h)
        h = self.decoder(h)
        
        return h, mu, logvar
  
    def get_codes(self, x):
        return self.encoder(x)
    
    def reparameterize(self, h):
        '''
        make mu, logvar using h(vector) linear layer
            Args: 
                decoder_dims -> list
                h : encoder last output -> vector(hidden dim)
            
            Return:
                z : mu + (std*eps) has same dims with encoder output and decoder input
        '''
        mu, logvar = h[:, :h.shape[1]//2], h[:, h.shape[1]//2:]
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        
        z = mu + (std*eps)
        return z, mu, logvar

 

 

VAE, Multi-VAE

간단하게는 loss function의 차이라고 생각하면 된다. 

- MSE를 사용하면 VAE

- BCE를 사용하면 Logistic-VAE

- CrossEntropy(Multi-Class)를 사용하면 Multi-VAE (Multi는 결국 Multinomial에서 온 말로 다항분포를 의미한다 = Multi-Class)

 

def vae_loss_function(recon_x, x, mu, logvar):
    loss = criterion(recon_mat, mat) #criterion is nn.CrossEntropyLoss()
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return loss+KLD

 

'ML model > Generative Model' 카테고리의 다른 글

Denoising AutoEncoder(DAE, Multi-DAE)  (0) 2024.02.19
AutoEncoder (for Recommend System)  (1) 2024.02.19