How to train an LSTM using TensorFlow

This recipe helps you train an LSTM using TensorFlow.

Recipe Objective

How to train an LSTM using TensorFlow?

LSTM stands for "Long Short-Term Memory", a type of recurrent neural network (RNN). LSTM networks are trained with backpropagation through time and are designed to overcome the vanishing gradient problem that affects plain RNNs. This makes it practical to build large recurrent networks that can address difficult sequence problems in machine learning and achieve state-of-the-art results. Instead of plain neurons, an LSTM network has memory blocks that are connected through layers. To build the LSTM model we will use Keras, which ships with TensorFlow; the LSTM layer is available as the "keras.layers.LSTM" class.
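To make the idea of a memory block more concrete, here is a minimal NumPy sketch of a single LSTM time step. It is an illustration only, not the Keras implementation: the function name lstm_step, the weight names W, U and b, the gate ordering, and the random weights are all assumptions made for this example.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM time step (hypothetical helper, for illustration).

    W: (input_dim, 4 * units), U: (units, 4 * units), b: (4 * units,)
    Gate order assumed here: input, forget, candidate, output.
    """
    units = h_prev.shape[-1]
    z = x_t @ W + h_prev @ U + b            # all four gate pre-activations at once
    i = sigmoid(z[:, 0 * units:1 * units])  # input gate
    f = sigmoid(z[:, 1 * units:2 * units])  # forget gate
    g = np.tanh(z[:, 2 * units:3 * units])  # candidate cell values
    o = sigmoid(z[:, 3 * units:4 * units])  # output gate
    c_t = f * c_prev + i * g                # updated cell state (the "memory")
    h_t = o * np.tanh(c_t)                  # updated hidden state
    return h_t, c_t

# Random weights and inputs, purely to exercise the function.
rng = np.random.default_rng(0)
batch, input_dim, units, steps = 2, 64, 128, 5
W = rng.normal(scale=0.1, size=(input_dim, 4 * units))
U = rng.normal(scale=0.1, size=(units, 4 * units))
b = np.zeros(4 * units)

h = np.zeros((batch, units))
c = np.zeros((batch, units))
for t in range(steps):
    x_t = rng.normal(size=(batch, input_dim))
    h, c = lstm_step(x_t, h, c, W, U, b)

print(h.shape)  # (2, 128)
```

The cell state c is carried from step to step and is only scaled by the forget gate and added to, which is the part of the design that helps gradients survive over long sequences.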


Step 1 - Import the library

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

Step 2 - Initialize Model

LSTM_model = keras.Sequential()

Here we use a simple Sequential model. It will process sequences of integers, embed each integer into a 64-dimensional vector, and then process the resulting sequence of vectors with an LSTM layer.

Step 3 - Add layers

LSTM_model.add(layers.Embedding(input_dim=1000, output_dim=64))
LSTM_model.add(layers.LSTM(128))
LSTM_model.add(layers.Dense(10))

Here we added three layers to the model. First, an Embedding layer with an input dimension of 1000 (the vocabulary size, so valid integer inputs are 0 to 999) and an output embedding dimension of 64. Next, an LSTM layer with 128 internal units. Finally, a Dense layer with 10 units.
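The layer sizes above fix the number of trainable parameters, and the counts can be worked out by hand. The short sketch below does the arithmetic in plain Python; the variable names are chosen for this example only. The results can be compared with the "Param #" column of the model summary in Step 4.

```python
# Layer sizes taken from Step 3.
vocab_size, embed_dim = 1000, 64
lstm_units = 128
dense_units = 10

# Embedding: one embed_dim-sized vector per vocabulary entry.
embedding_params = vocab_size * embed_dim

# LSTM: four gates, each with input weights, recurrent weights and a bias.
lstm_params = 4 * ((embed_dim + lstm_units) * lstm_units + lstm_units)

# Dense: weight matrix plus bias.
dense_params = lstm_units * dense_units + dense_units

total_params = embedding_params + lstm_params + dense_params
print(embedding_params, lstm_params, dense_params, total_params)
# 64000 98816 1290 164106
```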

Step 4 - Check Model summary

LSTM_model.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_5 (Embedding)      (None, None, 64)          64000     
_________________________________________________________________
lstm_2 (LSTM)                (None, 128)               98816     
_________________________________________________________________
dense_4 (Dense)              (None, 10)                1290      
=================================================================
Total params: 164,106
Trainable params: 164,106
Non-trainable params: 0
_________________________________________________________________
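The steps above build the model but do not yet train it. The following is a minimal sketch of one way to compile and fit it. The recipe supplies no dataset, so the random integer sequences and random labels, the sequence length of 20, the optimizer, the loss, the batch size, and the number of epochs are all assumptions made for illustration. The 10 Dense outputs are treated here as logits for a 10-class problem.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Same architecture as in Steps 2 and 3.
LSTM_model = keras.Sequential()
LSTM_model.add(layers.Embedding(input_dim=1000, output_dim=64))
LSTM_model.add(layers.LSTM(128))
LSTM_model.add(layers.Dense(10))

# Hypothetical stand-in data: 256 random sequences of 20 token ids
# (0..999) with random class labels (0..9). Replace with real data.
rng = np.random.default_rng(0)
x_train = rng.integers(0, 1000, size=(256, 20))
y_train = rng.integers(0, 10, size=(256,))

# The Dense layer has no activation, so the loss is told to expect logits.
LSTM_model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Train briefly; real problems will need more epochs and a validation split.
history = LSTM_model.fit(x_train, y_train, batch_size=32, epochs=2, verbose=0)

# One row of 10 logits per input sequence.
logits = LSTM_model.predict(x_train[:4], verbose=0)
print(logits.shape)  # (4, 10)
```

Because the labels here are random, the accuracy will stay near chance; the point is only to show the compile and fit calls on inputs of the shape this model expects.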


