What is batch normalization in Keras?

This recipe explains what batch normalization is in Keras.

Recipe Objective

In machine learning, the main goal is to build a model and predict the output. In deep learning and neural networks, training can suffer from internal covariate shift between layers, where the distribution of each layer's inputs changes as the weights of the preceding layers are updated. Batch normalization applies a transformation that keeps the mean of a layer's output close to 0 and its standard deviation close to 1.
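As a quick illustration (a minimal sketch, not part of the original recipe), calling a BatchNormalization layer in training mode normalizes each feature of a batch to roughly zero mean and unit standard deviation:

import numpy as np
from keras.layers import BatchNormalization

# A toy batch of 128 samples with 4 features, deliberately shifted and scaled
x = np.random.normal(loc=3.0, scale=5.0, size=(128, 4)).astype('float32')

bn = BatchNormalization()
# training=True makes the layer normalize with the statistics of this batch
y = np.array(bn(x, training=True))

print(y.mean(axis=0))  # close to 0 for every feature
print(y.std(axis=0))   # close to 1 for every feature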

So this recipe is a short example of batch normalization in Keras.


Step 1 - Import the library

import pandas as pd
import numpy as np
from keras.datasets import mnist
from sklearn.model_selection import train_test_split
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import BatchNormalization

We have imported pandas, numpy, mnist (the dataset), train_test_split, Sequential, BatchNormalization, Dense and Dropout. We will use these later in the recipe.

Step 2 - Loading the Dataset

Here we have used the built-in mnist dataset, stored the training data in X_train and y_train, and the test data in X_test and y_test.

(X_train, y_train), (X_test, y_test) = mnist.load_data()
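The model built in Step 3 expects flat 784-dimensional inputs, and 'categorical_crossentropy' expects one-hot labels (as the parameter counts in the summary later confirm), so a preprocessing step along the following lines is assumed; this is a sketch, not part of the original snippet:

from keras.utils import to_categorical

# Flatten the 28x28 images into 784-dimensional float vectors scaled to [0, 1]
X_train = X_train.reshape(-1, 784).astype('float32') / 255
X_test = X_test.reshape(-1, 784).astype('float32') / 255

# One-hot encode the digit labels for categorical_crossentropy
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)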

Step 3 - Model and Batch Normalization

We have created an object model for the Sequential model; it can take two arguments, layers and name. We then add layers using 'add', where we can specify the type of layer, the activation function and other options. Here we add batch normalization after every layer, which reduces the internal covariate shift between the layers.

model = Sequential()
model.add(Dense(512, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(BatchNormalization())
model.add(Dense(256, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.25))
model.add(BatchNormalization())
model.add(Dense(10))
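Each BatchNormalization layer carries four parameters per feature: a learnable scale (gamma) and offset (beta), plus the moving mean and moving variance used at inference time; only gamma and beta are trained. That is where the BatchNormalization rows in the summary below come from:

# 4 parameters per feature in each BatchNormalization layer
bn_after_512 = 4 * 512                       # 2048, as reported for the 512-unit layers
bn_after_256 = 4 * 256                       # 1024, as reported for the 256-unit layers

# The moving mean and variance are not trained, which accounts for the
# "Non-trainable params: 3,072" line in the summary
non_trainable = 2 * (512 + 512 + 256 + 256)  # 3072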

Step 4 - Compiling the model

We can compile the model using the compile method. Let us first look at its parameters before using it.

  • optimizer : the optimizer we want to use; there are various optimizers such as SGD, Adam, etc.
  • loss : the loss function we want the model to minimize
  • metrics : the metrics on which we want the model to be scored

model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])
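Note that the final Dense(10) layer in Step 3 has no softmax activation, so the model outputs raw logits, while the plain 'categorical_crossentropy' string treats its inputs as probabilities; this mismatch helps explain the low accuracy in the output below. A common variant (a sketch, not what this recipe runs) tells the loss to expect logits:

from keras.losses import CategoricalCrossentropy

# from_logits=True lets the loss apply the softmax internally,
# so the final Dense(10) layer can stay without an activation
model.compile(optimizer='Adam',
              loss=CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])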

Step 5 - Fitting the model

We can fit the model on the data and use it afterwards. Here we use the training data loaded in Step 2 to fit the model, and the test data as validation data. While fitting we can pass various parameters like batch_size, epochs, verbose, validation_data and so on.

model.fit(X_train, y_train,
          batch_size=128,
          epochs=2,
          verbose=1,
          validation_data=(X_test, y_test))
model.summary()

Step 6 - Evaluating the model

After fitting the model we want to evaluate it. Here we use model.evaluate, which gives us the loss and the accuracy, and we also print the score.

score = model.evaluate(X_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

Step 7 - Predicting the output

Finally we predict the output, using the test data loaded in Step 2. The note after the output below shows how to turn the raw predictions into class labels.

y_pred = model.predict(X_test)
print(y_pred)

As an output we get:

Epoch 1/2
469/469 [==============================] - 7s 16ms/step - loss: 7.6169 - accuracy: 0.2197 - val_loss: 8.3000 - val_accuracy: 0.1068
Epoch 2/2
469/469 [==============================] - 7s 15ms/step - loss: 8.2217 - accuracy: 0.1717 - val_loss: 8.7116 - val_accuracy: 0.1542
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_10 (Dense)             (None, 512)               401920    
_________________________________________________________________
batch_normalization_12 (Batc (None, 512)               2048      
_________________________________________________________________
dropout_6 (Dropout)          (None, 512)               0         
_________________________________________________________________
batch_normalization_13 (Batc (None, 512)               2048      
_________________________________________________________________
dense_11 (Dense)             (None, 256)               131328    
_________________________________________________________________
batch_normalization_14 (Batc (None, 256)               1024      
_________________________________________________________________
dropout_7 (Dropout)          (None, 256)               0         
_________________________________________________________________
batch_normalization_15 (Batc (None, 256)               1024      
_________________________________________________________________
dense_12 (Dense)             (None, 10)                2570      
=================================================================
Total params: 541,962
Trainable params: 538,890
Non-trainable params: 3,072
_________________________________________________________________
Test loss: 8.711593627929688
Test accuracy: 0.1542000025510788

[[-0.8443254  -0.12326024  2.869469   ...  0.04755732  2.2749598
  -3.7791016 ]
 [-0.65643924  0.03457718 -6.4518514  ... -0.05437452 -2.6341794
   1.9271061 ]
 [-1.0447047   3.3620064  -5.2135124  ...  3.5870833  -2.537289
   2.6774516 ]
 ...
 [-3.4818666   2.5188947   1.2750722  ... -2.3888552   3.608779
  -0.5627145 ]
 [-3.7964191   0.05416028 -1.3229418  ... -1.4461676   0.8728197
   2.606545  ]
 [-3.161486    0.7417487  -2.9847188  ... -4.3610487  -1.7754345
  -3.0061972 ]]
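Because the final Dense(10) layer has no softmax, these predictions are raw scores (logits) rather than probabilities. A minimal sketch of converting them into predicted digit labels, assuming numpy is imported as np:

# The index of the largest score in each row is the predicted class
predicted_labels = np.argmax(y_pred, axis=1)
print(predicted_labels[:10])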

