How to train an autoencoder with TensorFlow

This recipe shows you how to train an autoencoder with TensorFlow.

Recipe Objective

How do you train an autoencoder with TensorFlow?

An autoencoder is a special type of neural network that is trained to copy its input to its output. For example, given an image of a handwritten digit, the autoencoder first encodes the image into a lower-dimensional latent representation and then decodes that latent representation back into an image. By minimizing the reconstruction error between input and output, the autoencoder learns to compress the data.
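To make the encode, compress, decode loop concrete without TensorFlow, here is a minimal NumPy sketch. It swaps in a linear autoencoder fitted with truncated SVD (that is, PCA) instead of a trained neural network, and runs on made-up toy data; `linear_autoencoder` and `reconstruction_mse` are hypothetical helpers for illustration only. Because the toy data really lives on a 10-dimensional subspace, the reconstruction error collapses once the latent dimension reaches 10.

```python
import numpy as np

def linear_autoencoder(x, latent_dim):
    """Fit a PCA-style linear encoder/decoder pair with truncated SVD.

    This is a stand-in for a trained linear autoencoder, not the Keras
    model built later in this recipe.
    """
    mean = x.mean(axis=0)
    # Rows of vt are principal directions, strongest first.
    _, _, vt = np.linalg.svd(x - mean, full_matrices=False)
    components = vt[:latent_dim]                 # (latent_dim, n_features)

    def encode(batch):
        return (batch - mean) @ components.T     # (n, latent_dim)

    def decode(codes):
        return codes @ components + mean         # (n, n_features)

    return encode, decode

def reconstruction_mse(x, x_hat):
    """Mean squared difference between inputs and reconstructions."""
    return float(np.mean((x - x_hat) ** 2))

rng = np.random.default_rng(0)
# Toy data: 200 samples with 784 features that lie on a 10-dim subspace, plus small noise.
basis = rng.normal(size=(10, 784))
x = rng.normal(size=(200, 10)) @ basis + 0.01 * rng.normal(size=(200, 784))

errors = {}
for k in (2, 10, 64):
    encode, decode = linear_autoencoder(x, k)
    codes = encode(x)            # compressed latent representation
    x_hat = decode(codes)        # reconstruction from the latent space
    errors[k] = reconstruction_mse(x, x_hat)
    print(f"latent_dim={k:>2}  codes={codes.shape}  mse={errors[k]:.5f}")
```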


Step 1 - Import libraries

```python
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, losses
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.models import Model
```

Step 2 - Load Data

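```python
# Load Fashion MNIST; the labels are not needed to train an autoencoder
(x_train_data, _), (x_test_data, _) = fashion_mnist.load_data()

# Scale pixel values from the 0-255 integer range to floats in [0, 1]
x_train_data = x_train_data.astype('float32') / 255.
x_test_data = x_test_data.astype('float32') / 255.

print(x_train_data.shape)
print(x_test_data.shape)
```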
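The `astype('float32') / 255.` step maps 8-bit pixel values onto the [0, 1] range, which is convenient if the decoder ends with a sigmoid activation (a common choice, assumed here). A quick check on a small made-up array, not the real dataset, shows the effect:

```python
import numpy as np

# Made-up 8-bit "image" holding the darkest, a mid-grey and the brightest pixel value
raw = np.array([[[0, 128, 255]]], dtype=np.uint8)
scaled = raw.astype('float32') / 255.

print(scaled.dtype)                # float32
print(scaled.min(), scaled.max())  # 0.0 1.0
```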

(60000, 28, 28)
(10000, 28, 28)

Step 3 - Define Autoencoder

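```python
dimension_lat = 64  # size of the latent (compressed) vector

class Autoencoder(Model):
    def __init__(self, dimension_lat):
        super(Autoencoder, self).__init__()
        self.dimension_lat = dimension_lat
        self.encoder = tf.keras.Sequential([
            layers.Flatten(),
            layers.Dense(dimension_lat, activation='relu'),
        ])
        # Decoder: expand back to 784 pixels in [0, 1] and restore the 28x28 shape
        self.decoder = tf.keras.Sequential([
            layers.Dense(784, activation='sigmoid'),
            layers.Reshape((28, 28)),
        ])

    def call(self, x):
        encoded = self.encoder(x)
        return self.decoder(encoded)

autoencoder = Autoencoder(dimension_lat)
```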

Here the autoencoder is defined with two Dense layers: the encoder compresses each image into a 64-dimensional latent vector, and the decoder reconstructs the original image from that latent space.
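The data flow through this model can be traced with plain NumPy. The sketch below assumes the encoder is `Flatten` followed by `Dense(64, activation='relu')` and that the decoder is `Dense(784, activation='sigmoid')` followed by `Reshape((28, 28))`; the weights are random stand-ins rather than trained values, so only the shapes and parameter count are meaningful. For that architecture, Keras should report 50,240 encoder parameters and 50,960 decoder parameters.

```python
import numpy as np

LATENT_DIM = 64        # same value as dimension_lat in the recipe
INPUT_DIM = 28 * 28    # 784 pixels per Fashion MNIST image

rng = np.random.default_rng(42)
# Random weights stand in for trained ones; only the shapes matter here.
w_enc = rng.normal(scale=0.05, size=(INPUT_DIM, LATENT_DIM))
b_enc = np.zeros(LATENT_DIM)
w_dec = rng.normal(scale=0.05, size=(LATENT_DIM, INPUT_DIM))
b_dec = np.zeros(INPUT_DIM)

def relu(z):
    return np.maximum(z, 0.0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(images):
    flat = images.reshape(len(images), -1)   # Flatten: (n, 28, 28) -> (n, 784)
    latent = relu(flat @ w_enc + b_enc)       # Dense(64, activation='relu')
    out = sigmoid(latent @ w_dec + b_dec)     # assumed Dense(784, activation='sigmoid')
    return latent, out.reshape(-1, 28, 28)    # assumed Reshape((28, 28))

batch = rng.random((5, 28, 28))
latent, recon = forward(batch)
n_params = w_enc.size + b_enc.size + w_dec.size + b_dec.size

print(latent.shape, recon.shape)  # (5, 64) (5, 28, 28)
print(n_params)                   # 101200
```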

Step 4 - Compile Autoencoder

```python
autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())
```
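`losses.MeanSquaredError()` scores a reconstruction by its squared difference from the input. As we understand Keras's default reduction, it averages the squared error over the last axis and then averages the resulting values; for same-sized images that equals the plain mean over every pixel. A NumPy sketch with random stand-in data checks that equivalence (`mse_last_axis` is a hypothetical helper):

```python
import numpy as np

def mse_last_axis(y_true, y_pred):
    """Mean of squared differences over the last axis (image columns here)."""
    return np.mean((y_true - y_pred) ** 2, axis=-1)

rng = np.random.default_rng(1)
y_true = rng.random((4, 28, 28))
# A noisy "reconstruction" clipped back into the [0, 1] pixel range
y_pred = np.clip(y_true + rng.normal(scale=0.1, size=y_true.shape), 0.0, 1.0)

two_step = mse_last_axis(y_true, y_pred).mean()   # per-row means, then their mean
one_step = np.mean((y_true - y_pred) ** 2)        # mean over every pixel at once
print(np.isclose(two_step, one_step))             # True
```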

Step 5 - Train the autoencoder

```python
# The input images are also the targets: the model learns to reproduce them
autoencoder.fit(x_train_data, x_train_data,
                epochs=10,
                shuffle=True,
                validation_data=(x_test_data, x_test_data))
```

Epoch 1/10
1875/1875 [==============================] - 6s 3ms/step - loss: 0.0088 - val_loss: 0.0089
Epoch 2/10
1875/1875 [==============================] - 6s 3ms/step - loss: 0.0087 - val_loss: 0.0089
Epoch 3/10
1875/1875 [==============================] - 6s 3ms/step - loss: 0.0087 - val_loss: 0.0088
Epoch 4/10
1875/1875 [==============================] - 5s 3ms/step - loss: 0.0087 - val_loss: 0.0088
Epoch 5/10
1875/1875 [==============================] - 5s 3ms/step - loss: 0.0087 - val_loss: 0.0088
Epoch 6/10
1875/1875 [==============================] - 5s 3ms/step - loss: 0.0086 - val_loss: 0.0088
Epoch 7/10
1875/1875 [==============================] - 5s 3ms/step - loss: 0.0086 - val_loss: 0.0088
Epoch 8/10
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0086 - val_loss: 0.0087
Epoch 9/10
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0086 - val_loss: 0.0088
Epoch 10/10
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0086 - val_loss: 0.0087
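Two numbers in this log can be checked by hand. Each epoch runs 1875 steps because `fit()` uses a batch size of 32 when none is given, and 60,000 / 32 = 1875. The loss is a mean squared error on pixels scaled to [0, 1], so its square root gives a more intuitive sense of the typical per-pixel error:

```python
import math

n_train = 60000         # x_train_data.shape[0]
batch_size = 32         # Keras default when fit() gets no batch_size
steps_per_epoch = math.ceil(n_train / batch_size)

final_val_mse = 0.0087  # val_loss after epoch 10 in the log above
val_rmse = math.sqrt(final_val_mse)

print(steps_per_epoch)           # 1875
print(round(val_rmse, 4))        # 0.0933
print(round(val_rmse * 255, 1))  # 23.8 (on the original 0-255 grey scale)
```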

Step 6 - Encode and decode images

```python
# Compress the test images into 64-dimensional codes, then reconstruct them
images_encoded = autoencoder.encoder(x_test_data).numpy()
images_decode = autoencoder.decoder(images_encoded).numpy()
```
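`images_decode` has the same shape as `x_test_data`, so the two can be compared image by image. The sketch below uses a hypothetical helper `per_image_mse` and random stand-in arrays (so it runs without TensorFlow) to rank images by reconstruction error; applied to the real arrays, the top-ranked images would be the ones the autoencoder reproduces least faithfully.

```python
import numpy as np

def per_image_mse(originals, reconstructions):
    """Mean squared error for each image, averaged over its 28x28 pixels."""
    diff = originals - reconstructions
    return np.mean(diff ** 2, axis=(1, 2))

# Stand-ins for x_test_data and images_decode
rng = np.random.default_rng(7)
originals = rng.random((6, 28, 28)).astype('float32')
reconstructions = originals.copy()
reconstructions[2] = 1.0 - reconstructions[2]  # deliberately bad reconstruction of image 2

errors = per_image_mse(originals, reconstructions)
worst_first = np.argsort(errors)[::-1]  # indices sorted from largest to smallest error
print(worst_first[0])  # 2
```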

Step 7 - Plot Results

```python
num = 10
plt.figure(figsize=(40, 8))
for i in range(num):
    # Top row: original test images
    ax = plt.subplot(2, num, i + 1)
    plt.imshow(x_test_data[i])
    plt.title("Original image")
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # Bottom row: reconstructed images
    ax = plt.subplot(2, num, i + 1 + num)
    plt.imshow(images_decode[i])
    plt.title("Reconstructed image")
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
```
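`plt.subplot(2, num, k)` numbers the panels of a 2-row grid starting from 1, left to right along the top row and then along the bottom row. That is why index `i + 1` places original image `i` in the top row and `i + 1 + num` places its reconstruction directly underneath. A small pure-Python check of that mapping (`panel_position` is a hypothetical helper):

```python
def panel_position(k, ncols):
    """Return the zero-based (row, column) of 1-based subplot index k."""
    return divmod(k - 1, ncols)

num = 10
for i in (0, 9):
    print(panel_position(i + 1, num), panel_position(i + 1 + num, num))
# (0, 0) (1, 0)
# (0, 9) (1, 9)
```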
