How to save a TensorFlow model

This recipe helps you save a TensorFlow model

Recipe Objective

How to save a TensorFlow model?

Saving a model in TensorFlow means storing it so that it can be shared and others can recreate it. Many machine learning practitioners share saved models when they publish research models and techniques. To save a model, we first create it and then train it, so that we have trained weights (parameters) to store.

Step 1 - Install required library

# h5py is required to save models in the HDF5 format
!pip install -q pyyaml h5py

Step 2 - Import library

import os

import tensorflow as tf
from tensorflow import keras

Step 3 - Load the Data

# Load MNIST and keep only the first 1000 training and test examples
(images_data_train, images_train_labels), (images_data_test, images_test_labels) = tf.keras.datasets.mnist.load_data()

images_train_labels = images_train_labels[:1000]
images_test_labels = images_test_labels[:1000]

# Flatten each 28x28 image into a 784-element vector and scale pixel values to [0, 1]
images_data_train = images_data_train[:1000].reshape(-1, 28 * 28) / 255.0
images_data_test = images_data_test[:1000].reshape(-1, 28 * 28) / 255.0
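The slicing, reshaping and scaling in this step can be checked without downloading MNIST. The sketch below is minimal and assumes a random uint8 array (fake_images, hypothetical) standing in for the MNIST images, which are also uint8 arrays of shape (N, 28, 28):

```python
import numpy as np

# Hypothetical stand-in for MNIST: uint8 pixel values in [0, 255], shape (N, 28, 28)
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(1200, 28, 28), dtype=np.uint8)

# Same steps as the recipe: keep the first 1000 images, flatten each 28x28 image
# into a 784-element vector, and scale pixel values to the range [0, 1]
prepared = fake_images[:1000].reshape(-1, 28 * 28) / 255.0

print(prepared.shape)                                 # (1000, 784)
print(prepared.min() >= 0.0, prepared.max() <= 1.0)   # True True
```

The 784-element vectors match the input_shape=(784,) of the first Dense layer defined in the next step.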

Step 4 - Define the model

# Define a simple sequential model
def Make_model():
    My_model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation='relu', input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10)
    ])

    My_model.compile(optimizer='adam',
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                     metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    return My_model

# Create a basic model instance
My_model = Make_model()

# Display the model's architecture
My_model.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_8 (Dense)              (None, 512)               401920    
_________________________________________________________________
dropout_4 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_9 (Dense)              (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
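The parameter counts in the summary follow directly from the layer sizes. A Dense layer with n inputs and m units has n × m weights plus m biases, and Dropout has no trainable parameters. Here is a short check in plain Python (the helper name dense_params is hypothetical):

```python
def dense_params(n_inputs: int, n_units: int) -> int:
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units

first = dense_params(784, 512)   # 784*512 + 512 = 401920
second = dense_params(512, 10)   # 512*10 + 10 = 5130
dropout = 0                      # Dropout has no trainable parameters

total = first + second + dropout
print(first, second, total)      # 401920 5130 407050
```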

Step 5 - Save the Checkpoints

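# Checkpoint file path; its directory is inspected in Step 6
path_checkpoint = "training_1/cp.ckpt"
directory_checkpoint = os.path.dirname(path_checkpoint)

# Callback that saves only the model's weights at the end of every epoch
callback = tf.keras.callbacks.ModelCheckpoint(filepath=path_checkpoint,
                                              save_weights_only=True,
                                              verbose=1)

# Train the model, passing the callback so checkpoints are written during training
My_model.fit(images_data_train,
             images_train_labels,
             epochs=10,
             validation_data=(images_data_test, images_test_labels),
             callbacks=[callback])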

Epoch 1/10
32/32 [==============================] - 1s 12ms/step - loss: 1.5851 - sparse_categorical_accuracy: 0.5323 - val_loss: 0.6799 - val_sparse_categorical_accuracy: 0.8050

Epoch 00001: saving model to training_1/cp.ckpt
Epoch 2/10
32/32 [==============================] - 0s 7ms/step - loss: 0.4546 - sparse_categorical_accuracy: 0.8551 - val_loss: 0.5102 - val_sparse_categorical_accuracy: 0.8480

Epoch 00002: saving model to training_1/cp.ckpt
Epoch 3/10
32/32 [==============================] - ETA: 0s - loss: 0.2908 - sparse_categorical_accuracy: 0.9217
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
32/32 [==============================] - 1s 17ms/step - loss: 0.2907 - sparse_categorical_accuracy: 0.9217 - val_loss: 0.4689 - val_sparse_categorical_accuracy: 0.8560

Epoch 00003: saving model to training_1/cp.ckpt
Epoch 4/10
32/32 [==============================] - 0s 7ms/step - loss: 0.1676 - sparse_categorical_accuracy: 0.9643 - val_loss: 0.4308 - val_sparse_categorical_accuracy: 0.8610

Epoch 00004: saving model to training_1/cp.ckpt
Epoch 5/10
32/32 [==============================] - 0s 8ms/step - loss: 0.1548 - sparse_categorical_accuracy: 0.9681 - val_loss: 0.4265 - val_sparse_categorical_accuracy: 0.8590

Epoch 00005: saving model to training_1/cp.ckpt
Epoch 6/10
32/32 [==============================] - 0s 7ms/step - loss: 0.1380 - sparse_categorical_accuracy: 0.9767 - val_loss: 0.4116 - val_sparse_categorical_accuracy: 0.8620

Epoch 00006: saving model to training_1/cp.ckpt
Epoch 7/10
32/32 [==============================] - 0s 7ms/step - loss: 0.0871 - sparse_categorical_accuracy: 0.9902 - val_loss: 0.3967 - val_sparse_categorical_accuracy: 0.8690

Epoch 00007: saving model to training_1/cp.ckpt
Epoch 8/10
32/32 [==============================] - 0s 7ms/step - loss: 0.0598 - sparse_categorical_accuracy: 0.9938 - val_loss: 0.3946 - val_sparse_categorical_accuracy: 0.8750

Epoch 00008: saving model to training_1/cp.ckpt
Epoch 9/10
32/32 [==============================] - 0s 7ms/step - loss: 0.0431 - sparse_categorical_accuracy: 0.9995 - val_loss: 0.3989 - val_sparse_categorical_accuracy: 0.8730

Epoch 00009: saving model to training_1/cp.ckpt
Epoch 10/10
32/32 [==============================] - 0s 8ms/step - loss: 0.0378 - sparse_categorical_accuracy: 1.0000 - val_loss: 0.4008 - val_sparse_categorical_accuracy: 0.8720

Epoch 00010: saving model to training_1/cp.ckpt

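Here we define a path for saving checkpoints during training and create a ModelCheckpoint callback that saves the model's weights (save_weights_only=True). Because the file path contains no epoch number, the same checkpoint is overwritten at the end of every epoch. We then train the model with this callback.

The run may print warnings about unresolved optimizer objects in a checkpoint. As the message itself explains, they report that a checkpoint was restored without using all of its saved values (here, the optimizer state). They do not affect the weights being saved, so they can be ignored in this recipe.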
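The callback in Step 5 writes to the same file every epoch, so the checkpoint always holds the last epoch rather than the best one. ModelCheckpoint also accepts monitor and save_best_only to keep only the best epoch.

The sketch below is minimal and rests on several assumptions:
- Random data stands in for MNIST.
- A smaller stand-in model replaces Make_model().
- The output file name best.weights.h5 is hypothetical. The .weights.h5 suffix is accepted for weights-only saving by both Keras 2 and Keras 3.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: 200 random "images" of 784 features, 10 classes
rng = np.random.default_rng(0)
x = rng.random((200, 784)).astype("float32")
y = rng.integers(0, 10, size=200)

# Smaller stand-in for the recipe's model (assumption)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(10),
])
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])

# Hypothetical output location in a temporary directory
out_dir = tempfile.mkdtemp()
best_path = os.path.join(out_dir, "best.weights.h5")

# Save the weights only when val_loss improves on the best value seen so far
best_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=best_path,
    monitor="val_loss",
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

model.fit(x, y, epochs=3, validation_split=0.2,
          callbacks=[best_callback], verbose=0)
print(os.path.exists(best_path))  # True: the first epoch always counts as an improvement
```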

Step 6 - Check the checkpoint directory

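ls {directory_checkpoint}

checkpoint  cp.ckpt.data-00000-of-00001  cp.ckpt.index

The checkpoint file records the most recent checkpoint. The .index file records which weights are stored in which shard, and the .data-00000-of-00001 shard holds the weight values themselves.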

Step 7 - Create model instance and evaluate

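# Create a new, untrained instance with the same architecture
My_model = Make_model()

# Evaluate it on the test data: accuracy should be near chance (about 10% for 10 classes)
loss, accuracy_d = My_model.evaluate(images_data_test, images_test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * accuracy_d))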

32/32 - 0s - loss: 2.4090 - sparse_categorical_accuracy: 0.1020
Untrained model, accuracy: 10.20%

Step 8 - Load the weights and re-evaluate

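# Load the weights saved by the checkpoint callback into the untrained model
My_model.load_weights(path_checkpoint)

# Re-evaluate: accuracy should now match the trained model's final epoch
loss, accuracy_d = My_model.evaluate(images_data_test, images_test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * accuracy_d))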

32/32 - 0s - loss: 0.4008 - sparse_categorical_accuracy: 0.8720
Restored model, accuracy: 87.20%
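Because the checkpoint contains only weights, it can be loaded into any model with the same architecture. That is why Step 7 rebuilds the model with Make_model() before loading. The same round trip can be done manually with save_weights and load_weights.

The sketch below assumes a small stand-in model built by a hypothetical build_model helper and a hypothetical file name ending in .weights.h5. It checks that the restored copy produces identical predictions:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def build_model():
    # Hypothetical helper: same architecture on every call, fresh random weights
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(784,)),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    model.compile(optimizer="adam",
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
    return model

x = np.random.default_rng(1).random((5, 784)).astype("float32")

original = build_model()
weights_path = os.path.join(tempfile.mkdtemp(), "manual.weights.h5")
original.save_weights(weights_path)   # weights only: no architecture, no optimizer

restored = build_model()              # starts from different random weights
restored.load_weights(weights_path)   # now holds the same weights as `original`

same = np.allclose(original.predict(x, verbose=0), restored.predict(x, verbose=0))
print(same)  # True
```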

Step 9 - Save the model

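# Save the entire model (architecture, weights and training configuration)
!mkdir -p saved_model
# In TF 2.x releases before 2.16 (Keras 2), a path without an extension is saved in the SavedModel format;
# Keras 3 (TF 2.16 and later) instead requires a file name ending in .keras or .h5
My_model.save('saved_model/my_model')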

INFO:tensorflow:Assets written to: saved_model/my_model/assets
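A saved model is only useful if it can be loaded again. The sketch below shows a minimal round trip using the single-file HDF5 format, which is what h5py from Step 1 is for. It assumes a smaller stand-in model and a hypothetical file name my_model.h5.

tf.keras.models.load_model rebuilds the architecture and weights from the file alone, so Make_model() is not needed to recreate the model:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Smaller stand-in for the recipe's model (assumption: same layer types, fewer units)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10),
])
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=["accuracy"])

# Hypothetical file name; the .h5 suffix selects the HDF5 format (requires h5py)
h5_path = os.path.join(tempfile.mkdtemp(), "my_model.h5")
model.save(h5_path)

# Recreate the model from the file alone
reloaded = tf.keras.models.load_model(h5_path)

# Dropout is inactive in predict(), so both models give the same outputs
x = np.random.default_rng(2).random((5, 784)).astype("float32")
match = np.allclose(model.predict(x, verbose=0), reloaded.predict(x, verbose=0))
print(match)  # True
```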
