This recipe builds an autoencoder for compressing the number of features in the MNIST handwritten digits dataset.
To build an autoencoder, this recipe uses three components:
- an encoding function,
- a decoding function,
- a loss function that measures the amount of information lost between the original data and the representation decompressed from its compressed form.
What is an autoencoder?
An autoencoder is used to encode features so that they take up much less storage space while still effectively representing the same data. It is a type of neural network that learns efficient data codings in an unsupervised way. The aim of an autoencoder is to learn a representation of a dataset, typically for dimensionality reduction, by training the network to ignore signal "noise".
What is an unsupervised learning model?
Unsupervised machine learning models infer patterns from a dataset without reference to known, or labeled, outcomes. Unlike supervised machine learning, unsupervised machine learning methods cannot be directly applied to a regression or classification problem, because you do not know what the values of the output data should be, which makes it impossible to train the algorithm the way you normally would. Unsupervised learning can instead be used to discover the underlying structure of the data.
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np
%matplotlib inline
torch.manual_seed(1)  # reproducible

# Hyper Parameters
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005  # learning rate
# Set to True on the first run if ./mnist/ does not contain the dataset yet;
# with False, torchvision expects the files to be present already.
DOWNLOAD_MNIST = False
N_TEST_IMG = 5  # number of digits shown in the original-vs-reconstruction figure

# Mnist digits dataset
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,  # this is training data
    # Converts a PIL.Image or numpy.ndarray to torch.FloatTensor of shape
    # (C x H x W) and normalizes it to the range [0.0, 1.0]
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,  # download it if you don't have it
)

# plot one example
# NOTE: `.data` / `.targets` are the current names of the deprecated
# `.train_data` / `.train_labels` attributes; they hold the raw uint8 images
# and the integer labels (the `transform` above is NOT applied to them).
print(train_data.data.size())  # (60000, 28, 28)
print(train_data.targets.size())  # (60000)
plt.imshow(train_data.data[2].numpy(), cmap='gray')
plt.title('%i' % train_data.targets[2].item())
plt.show()

# Data Loader for easy mini-batch return in training, the image batch shape
# will be (BATCH_SIZE, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
class AutoEncoder(nn.Module):
    """Fully connected autoencoder with a narrow bottleneck.

    The encoder maps a flattened input through 128 -> 64 -> 12 hidden units
    (Tanh activations) down to ``latent_dim`` features; the decoder mirrors
    that path back up to ``input_dim`` values and squashes them with a
    Sigmoid so the reconstruction lies in (0, 1).

    Args:
        input_dim: Number of features per sample. Defaults to ``28 * 28``,
            i.e. one flattened MNIST image.
        latent_dim: Size of the compressed code. Defaults to 3, which can be
            visualized directly in a 3-D plot.
    """

    def __init__(self, input_dim=28 * 28, latent_dim=3):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
            nn.Tanh(),
            # No activation on the code layer: the latent features are unbounded.
            nn.Linear(12, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 12),
            nn.Tanh(),
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, input_dim),
            nn.Sigmoid(),  # compress to a range (0, 1)
        )

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Args:
            x: Tensor whose last dimension has ``input_dim`` features.

        Returns:
            Tuple ``(encoded, decoded)``: the latent code and the
            reconstruction of ``x`` produced from that code.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
def _show_reconstructions(originals, reconstructions):
    """Plot original digits (top row) above their reconstructions (bottom row).

    Args:
        originals: 2-D tensor holding one flattened 28x28 image per row.
        reconstructions: Tensor of the same shape with the decoder outputs.
    """
    n_images = originals.size(0)
    # squeeze=False keeps `axes` two-dimensional even when n_images == 1.
    _, axes = plt.subplots(2, n_images, figsize=(5, 2), squeeze=False)
    for row, images in enumerate((originals, reconstructions)):
        for col in range(n_images):
            ax = axes[row][col]
            ax.imshow(images[col].detach().numpy().reshape(28, 28), cmap='gray')
            ax.set_xticks(())
            ax.set_yticks(())
    plt.show()


autoencoder = AutoEncoder()
print(autoencoder)

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()

# original data (first row) for viewing
# `.data` holds raw uint8 pixels, so flatten and rescale to [0, 1] by hand to
# match what ToTensor() produces for the training batches.
view_data = train_data.data[:N_TEST_IMG].view(-1, 28 * 28).float() / 255.

for epoch in range(EPOCH):
    for step, (x, _) in enumerate(train_loader):  # labels are not needed
        b_x = x.view(-1, 28 * 28)  # batch x, shape (batch, 28*28)

        # The reconstruction target is the input itself.
        _, decoded = autoencoder(b_x)
        loss = loss_func(decoded, b_x)  # mean square error
        optimizer.zero_grad()  # clear gradients for this training step
        loss.backward()  # backpropagation, compute gradients
        optimizer.step()  # apply gradients

        # Report progress only every 500 steps of the first, sixth and last epoch.
        if step % 500 == 0 and epoch in [0, 5, EPOCH - 1]:
            # `.item()` extracts the Python float; indexing a 0-dim tensor
            # (`loss.data[0]`) is an error on current PyTorch.
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item())

            # plotting decoded image (second row); no gradients needed here
            with torch.no_grad():
                _, decoded_data = autoencoder(view_data)
            _show_reconstructions(view_data, decoded_data)