How to save a model using pytorch

This recipe helps you save a model using pytorch

Recipe Objective

How to save model using PyTorch?

This is achieved by using the torch.save function, which will save a serialized object to the disk, For serialization it uses the python's pickle utility. All kind of objects like models, dictionaries, tensors can be saved using this.

Get Access to Plant Species Identification Project using Machine Learning

Step 1 - Import library

import torch
import torch.nn as nn
import torch.optim as optim

Step 2 - Define Model

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

   def pass_forward(self, X_data):
       X_data = self.pool(F.relu(self.conv1(X_data)))
       X_data = self.pool(F.relu(self.conv2(X_data)))
       X_data = X_data.view(-1, 16 * 5 * 5)
       X_data = F.relu(self.fc1(X_data))
       X_data = F.relu(self.fc2(X_data))
       X_data = self.fc3(X_data)
       return X_data
Model = MyModel()
print("This is our Model parameters:",Model)

This is our Model parameters: MyModel(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

Step 3 - Initializing optimizer

optim = optim.SGD(Model.parameters(), lr=0.01, momentum=0.9)

Step 4 - Accessing Model

print("Accessing the model state_dict")
for values in Model.state_dict():
  print(values, "\t", Model.state_dict()[values].size())

Accessing the model state_dict
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

Step 5 - Accessing Optimizer

print("Accessing the optimizers state_dict")
for elements in optim.state_dict():
  print(elements, "\t", optim.state_dict()[elements])

Accessing the optimizers state_dict
state 	 {}
param_groups 	 [{'lr': 0.01, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]

Step 6 - Save the model

torch.save(Model.state_dict, "Model.pickle") {"mode":"full","isActive":false}

What Users are saying..

profile image

Anand Kumpatla

Sr Data Scientist @ Doubleslash Software Solutions Pvt Ltd
linkedin profile url

ProjectPro is a unique platform and helps many people in the industry to solve real-life problems with a step-by-step walkthrough of projects. A platform with some fantastic resources to gain... Read More

Relevant Projects

Build an optimal End-to-End MLOps Pipeline and Deploy on GCP
Learn how to build and deploy an end-to-end optimal MLOps Pipeline for Loan Eligibility Prediction Model in Python on GCP

Model Deployment on GCP using Streamlit for Resume Parsing
Perform model deployment on GCP for resume parsing model using Streamlit App.

Build a Face Recognition System in Python using FaceNet
In this deep learning project, you will build your own face recognition system in Python using OpenCV and FaceNet by extracting features from an image of a person's face.

Customer Market Basket Analysis using Apriori and Fpgrowth algorithms
In this data science project, you will learn how to perform market basket analysis with the application of Apriori and FP growth algorithms based on the concept of association rule learning.

Learn to Build Generative Models Using PyTorch Autoencoders
In this deep learning project, you will learn how to build a Generative Model using Autoencoders in PyTorch

Build a Similar Images Finder with Python, Keras, and Tensorflow
Build your own image similarity application using Python to search and find images of products that are similar to any given product. You will implement the K-Nearest Neighbor algorithm to find products with maximum similarity.

Build a Multi Touch Attribution Machine Learning Model in Python
Identifying the ROI on marketing campaigns is an essential KPI for any business. In this ML project, you will learn to build a Multi Touch Attribution Model in Python to identify the ROI of various marketing efforts and their impact on conversions or sales..

Build CNN Image Classification Models for Real Time Prediction
Image Classification Project to build a CNN model in Python that can classify images into social security cards, driving licenses, and other key identity information.

GCP MLOps Project to Deploy ARIMA Model using uWSGI Flask
Build an end-to-end MLOps Pipeline to deploy a Time Series ARIMA Model on GCP using uWSGI and Flask

Build an End-to-End AWS SageMaker Classification Model
MLOps on AWS SageMaker -Learn to Build an End-to-End Classification Model on SageMaker to predict a patient’s cause of death.