帶你認識Vector-Quantized Variational AutoEncoder — Pytorch實作篇

Tan
Taiwan AI Academy
Published in
23 min readJul 16, 2020

上一篇大致上簡介了VQ-VAE的模型架構與訓練方法,在這邊我們實際來建立一個VQ-VQE模型。本次參考了此位MishaLaskin的github實踐,使用到的框架是pytorch,由Facebook的人工智慧研究團隊所主導開發,是近年來除了google的tensorflow之外數一數二熱門的深度學習框架,程式碼易讀性很高,在使用上也非常有彈性。廢話不多說,就讓我們來看一下程式碼的部分吧。

註1:如果你對VQ-VAE的概念還不熟的話仍然建議先了解一下整個模型架構以及損失函數的部分

註2:medium的程式碼在排版上比較困難,需要看比較簡潔易懂的排版可以參考github的程式碼

載入套件與資料

實作的起手式大致雷同,一開始同樣是載入套件的部分(若您環境未有此套件再請是先自行安裝),另外為了展示方便,我們使用pytorch內建的CIFAR10資料集做示範。

from __future__ import print_functionimport matplotlib.pyplot as plt
import numpy as np
from six.moves import xrange
import umap
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchsummary import summary
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid

另外在pytorch中我們可以使用以下的指令建立一個torch.device類型的變數決定要使用cpu還是gpu做運算

device = torch.device(“cuda” if torch.cuda.is_available() else “cpu”)

接著使用以下的程式碼讀取資料,這邊pytorch已經幫我們切好訓練集與測試集(透過train參數做控制)。前處理的部分除了轉換為tensor型態外我們單純做了標準化(mean=0.5, std=1.0)的步驟,對應的程式碼部分式transform參數後的內容。

training_data = datasets.CIFAR10(
root=’data’, train=True, download=True,
transform=transforms.Compose(
[transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),
(1.0,1.0,1.0))]))
validation_data = datasets.CIFAR10(
root=’data’, train=False, download=True,
transform=transforms.Compose(
[transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),
(1.0,1.0,1.0))]))

這邊也預先計算資料的變異程度並命名為data_variance變數,當作後續損失函數的其中一個 scaling factor (但個人認為影響不大)

data_variance = np.var(training_data.data / 255.0)

建立模型

讀取資料結束後,我們就可以開始建立模型。在此先列出下面會使用到的模型與訓練時會使用到的超參數。

batch_size = 256
num_training_updates = 15000
num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2
embedding_dim = 64
num_embeddings = 512
commitment_cost = 0.25
decay = 0.99
learning_rate = 1e-3
Figure 1. model structure of VQ-VAE

上面的圖中雖簡單列出VQ-VAE的模型架構,然而並沒有特別列出Encoder與Decoder的細節,在此我們邊建立Encoder與Decoder的模組邊做說明。

class Residual(nn.Module):

def __init__(self,in_channels,num_hiddens,num_residual_hiddens):
super().__init__()

self._block = nn.Sequential(nn.ReLU(True),
nn.Conv2d(in_channels=in_channels,
out_channels=num_residual_hiddens,
kernel_size=3, stride=1, padding=1, bias=False),nn.ReLU(True),
nn.Conv2d(in_channels=num_residual_hiddens,
out_channels=num_hiddens, kernel_size=1, stride=1, bias=False))

def forward(self, x):
return x + self._block(x)
class ResidualStack(nn.Module):
def __init__(self,in_channels,num_hiddens,num_residual_layers,
num_residual_hiddens):
super().__init__()

self._num_residual_layers = num_residual_layers

self._layers = nn.ModuleList([
Residual(in_channels, num_hiddens, num_residual_hiddens)
for _ in range(self._num_residual_layers)])

def forward(self, x):
for i in range(self._num_residual_layers):
x = self._layers[i](x)
return F.relu(x)

上面首先建立了模型所需要使用的殘差模組(Residual Module),這邊convolution的設定與原始論文中Experiments段落提到的超參數皆相同,每個Residual block內的input皆通過 relu、3x3 convolution、relu、再與原始input相加產生output,同樣的Residual block重複兩次。

class Encoder(nn.Module):
def __init__(self,in_channels,num_hiddens,
num_residual_layers,num_residual_hiddens):
super().__init__()

self._conv_1 = nn.Conv2d(in_channels=in_channels,
out_channels=num_hiddens//2,
kernel_size=4, stride=2, padding=1)
self._conv_2 = nn.Conv2d(in_channels=num_hiddens//2,
out_channels=num_hiddens,
kernel_size=4, stride=2, padding=1)
self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
out_channels=num_hiddens,
kernel_size=3, stride=1, padding=1)
self._residual_stack = ResidualStack(in_channels=num_hiddens,
num_hiddens=num_hiddens,num_residual_layers=num_residual_layers,
num_residual_hiddens=num_residual_hiddens)

def forward(self, inputs):
x = self._conv_1(inputs)
x = F.relu(x)

x = self._conv_2(x)
x = F.relu(x)

x = self._conv_3(x)
return self._residual_stack(x)

class Decoder(nn.Module):
def __init__(self, in_channels, out_channels, num_hiddens,
num_residual_layers, num_residual_hiddens):
super().__init__()

self._conv_1 = nn.Conv2d(in_channels=in_channels,
out_channels=num_hiddens,kernel_size=3, stride=1, padding=1)

self._residual_stack = ResidualStack(in_channels=num_hiddens,
num_hiddens=num_hiddens,num_residual_layers=num_residual_layers,
num_residual_hiddens=num_residual_hiddens)

self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
out_channels=num_hiddens//2,kernel_size=4, stride=2, padding=1)

self._conv_trans_2 = nn.ConvTranspose2d(
in_channels=num_hiddens//2,out_channels=out_channels,
kernel_size=4, stride=2, padding=1)

def forward(self, inputs):
x = self._conv_1(inputs)
x = self._residual_stack(x)
x = self._conv_trans_1(x)
x = F.relu(x)
return self._conv_trans_2(x)

接著我們定義AutoEncoder的Encoder與Decoder,在Encoder的部分將會輸入原圖並輸出hidden representation (擁有三個維度),在此將輸入圖片維度(in_channels)、卷積層內的卷積數量(num_hiddens)、Residual block數量(num_residual_layers)、以及Residual block內的卷積數量(num_residual_hiddens)皆設為參數可供設定,為避免內容過於冗長,其餘超參數如各層的filter_size再請各位直接從程式碼中閱讀。

Decoder則需要接收quantization 後再做轉換的hidden representation (同樣為三個維度),再嘗試還原為原圖,在架構上為了保持讓hidden representation能保有較多彈性,在此先將任意維度通過一個卷積層調整channel數量,再通過Residual stack層與兩層的反卷積層(transpose convolution layer),同樣需要設定輸入representation的維度(in_channels)、輸出圖片的維度(out_channels)、卷積層內的卷積數量(num_hiddens)、 Residual block數量(num_residual_layers)、以及Residual block內的卷積數量(num_residual_hiddens),同樣地其他參數請各位至程式碼中直接觀看。

若依照先前設定之模型超參數建立模型,Encoder與Decoder的Model Summary將與Figure 2. 與 Figure 3.相同。

Figure 2. Model Summary of Encoder
Figure 3. Model Summary of Decoder

Note:此篇的整體模型與原論文不同,除了Encoder與Decoder原本的層數外,後續在整合成VQ-VAE整體模型時還會Encoder後的輸出會再通過一次convolution,這是需要注意的地方。

建立好encoder與decoder後,接下來我們來處理這一篇的特點Vector Quantization部分。

class VectorQuantizer(nn.Module):
def __init__(self, num_embeddings, embedding_dim,
commitment_cost):
super(VectorQuantizer, self).__init__()

self._embedding_dim = embedding_dim
self._num_embeddings = num_embeddings
self._embedding = nn.Embedding(self._num_embeddings,
self._embedding_dim)
self._embedding.weight.data.uniform_(-1/self._num_embeddings,
1/self._num_embeddings)
self._commitment_cost = commitment_cost

def forward(self, inputs):
# convert inputs from (B, C, H, W) to (B, H, W, C)
inputs = inputs.permute(0, 2, 3, 1).contiguous()
input_shape = inputs.shape

# flatten input to 2D (B * H * W , C)
flat_input = inputs.view(-1, self._embedding_dim)

# calculate distances (euclidean)
distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
+ torch.sum(self._embedding.weight**2, dim=1)
— 2 * torch.matmul(flat_input, self._embedding.weight.t()))

# Encoding
encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
encodings = torch.zeros(encoding_indices.shape[0],
self._num_embeddings, device=inputs.device)
encodings.scatter_(1, encoding_indices, 1)
# 針對axis=1的地方,將encoding_indices中的每個index位置改為1

# Quantize and unflaten
quantized = torch.matmul(encodings,
self._embedding.weight).view(input_shape)

# Loss .detach()這個method是關鍵
e_latent_loss = F.mse_loss(quantized.detach(), inputs) # detach()
q_latent_loss = F.mse_loss(quantized, inputs.detach()) # detach()
loss = q_latent_loss + self._commitment_cost * e_latent_loss

quantized = inputs + (quantized — inputs).detach()
avg_probs = torch.mean(encodings, dim=0)
perplexity = torch.exp(-torch.sum(avg_probs *
torch.log(avg_probs + 1e-10)))

# convert quantized from (B, H, W, C) to (B, C, H, W)
quantized = quantized.permute(0, 3, 1, 2)

return loss, quantized.contiguous(), perplexity, encodings
Figure 2. VQ-VAE loss

上面同樣列出VQ-VAE 的損失函數給大家做參考。VectorQuantizer類別主要將在輸入Ze(x)後進行Vector Quantization的步驟,並且輸出此部分的loss (L中第二項與第三項)、進行Vector Quantization步驟轉換後的hidden representation (用於通過decoder後還原成原資料)、目前的codebook權重的困惑度(perplexity)、以及資料的codebook 編碼。值得一提的是,在sg (stop sign)的部分pytorch提供了一個簡潔的方法.detach()讓特定的tensor不會被計算到梯度,使得我們可以簡單地實踐這個loss function,除了第二與第三項loss term使用到sg外,Straight-through estimator的概念也被很巧妙的用inputs + (quantized — inputs).detach()一個式子完成(前向傳導時輸出為quantized後的結果,然而倒傳遞時計算的是inputs的梯度)。

完成了每個部份的模組後,我們就可以把整個模型搭建起來。

class Model(nn.Module):
def __init__(self, num_hiddens, num_residual_layers,
num_residual_hiddens, num_embeddings, embedding_dim,
commitment_cost, decay=0):

super().__init__()

self._encoder = Encoder(3, num_hiddens, num_residual_layers,
num_residual_hiddens)

self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
out_channels=embedding_dim, kernel_size=1, stride=1)

self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
commitment_cost)

self._decoder = Decoder(embedding_dim, 3, num_hiddens,
num_residual_layers, num_residual_hiddens)

def forward(self, x):
z = self._encoder(x)
z = self._pre_vq_conv(z)
loss, quantized, perplexity, _ = self._vq_vae(z)
x_recon = self._decoder(quantized)

return loss, x_recon, perplexity

建立model與optimizer,在此同樣使用Adam作為我們的優化器。另外,在pytorch的框架下,每個tensor都可以被快速地配置到cpu / gpu 做計算,因此在這邊我們也依據最一開始決定的device變數將模型權重配置到cpu/gpu上。

model = Model(num_hiddens, num_residual_layers, num_residual_hiddens, num_embeddings, embedding_dim, commitment_cost, decay).to(device)optimizer = optim.Adam(model.parameters(), lr=learning_rate, amsgrad=False)

模型訓練

training_loader = DataLoader(training_data, batch_size = batch_size, shuffle = True, pin_memory = True)
validation_loader = DataLoader(validation_data, batch_size = batch_size, shuffle = True, pin_memory = True)

執行以下程式碼即開始做模型訓練。在每次的更新中,我們會從training_loader讀取一個batch的資料,將其配置到cpu或gpu中(依照device變數決定),接著先初始化優化器的梯度儲存處使其為0,再做前向傳導得到輸出與損失函數大小。接著使用.backward()方法計算梯度,再使用optimizer.step()方法做權重的更新,這樣就完成了一次的訓練。最後,在每一次訓練階段我們同樣將誤差以及codebook的perplexity指標記錄下來以便觀察,訓練結束後使用torch.save()函數將模型存出。

# model.train()
train_res_recon_error = []
train_res_perplexity = []
for i in xrange(num_training_updates):
(data, _) = next(iter(training_loader))
data = data.to(device)
optimizer.zero_grad()

vq_loss, data_recon, perplexity = model(data)
recon_error = F.mse_loss(data_recon, data) / data_variance
loss = recon_error + vq_loss
loss.backward()

optimizer.step()

train_res_recon_error.append(recon_error.item())
train_res_perplexity.append(perplexity.item())

if (i+1) % 100 ==0:
print(‘{:d} iterations, recon_error : {:.3f}, perplexity: {:.3f}\r\n’.format(i+1, np.mean(train_res_recon_error[-100:]),
np.mean(train_res_perplexity[-100:])))
PATH=’saved_models/vqvae_params.pkl’
torch.save(model.state_dict(), PATH)

在訓練了15000個iteration後,Reconstruction Error約莫從0.7降到0.05。

重建圖像

在最初的影像維度縮減作業中,目標是希望在較少量的空間需求下儲存高品質的圖像,VQ-VAE能夠幫助我們將圖像轉為潛在空間的一組編碼,我們只需要儲存每張圖片的離散編碼、整本codebook、以及Decoder網路就可以得到原始的(理論上)高品質圖像。但還原的效果如何呢,我們可以使用以下程式碼作驗證資料的還原。

def show(img):
npimg = img.numpy()
fig = plt.imshow(np.transpose(npimg, (1, 2, 0)),
interpolation='nearest')
fig.axes.get_xaxis().set_visible(False)
fig.axes.get_yaxis().set_visible(False)

首先先建立畫出圖形的函數。接著我們依序使用model中的各個模組進行資料的壓縮與還原。

(valid_originals, _) = next(iter(validation_loader))
valid_originals = valid_originals.to(device)
vq_output_eval = model._pre_vq_conv(model._encoder(valid_originals))
_, valid_quantize, _, _ = model._vq_vae(vq_output_eval)
valid_reconstructions = model._decoder(valid_quantize)

以下為原始的驗證集圖片

show(make_grid(valid_originals.cpu()[:16,:,:,:]+0.5))

下方則為通過VQ-VAE再還原後的圖片

show(make_grid(valid_reconstructions.cpu().data[:16,:,:,:])+0.5, )

可以看到在整體輪廓雖皆相同,但細緻度仍然與原圖會有些差異(畢竟作了壓縮,無法無痛還原)。

給個總結

在這篇文章中我們從頭到尾建立了一個VQ-VAE模型,並在CIFAR10資料集上作了訓練與驗證,相對而言這邊使用到的模型大小並不算太小(相對於資料解析度而言),在架構上與原論文相比也較深(但壓縮的比例與Codebook的數量相同),不過結果看起來和原始論文中呈現的效果有一段差距,也許仍然是一些超參數的設定上需要再做搜尋與測試。在此篇文章中由於篇幅關係也省略了作者在原論文中提到的另一種Exponential Moving Average (EMA)的梯度更新演算法,有興趣的人也可以再去github中搜尋與參考。

參考資源

在此同樣列出其他實踐VQ-VAE的github專案,做法也都相當不錯,但在類別的設計上可能會有些差異,若大家覺得上面的程式碼寫法不夠好或想多方參考的話可以使用。

--

--