How to load a model using PyTorch

This recipe shows you how to load a model using PyTorch.

Recipe Objective

How to load a model using PyTorch?

The torch.load function is used to load a saved model. It relies on Python's pickle unpickling facilities to deserialize a saved object file back into memory.
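As a minimal sketch of that serialize/deserialize round trip, the snippet below saves a small object and loads it back. The `payload` dictionary and the in-memory buffer are illustrative assumptions, not part of the recipe; a file path would work the same way.

```python
import io

import torch

# Hypothetical example object: a plain dict holding one tensor.
payload = {"weights": torch.arange(6, dtype=torch.float32).reshape(2, 3)}

# torch.save serializes the object; an in-memory buffer stands in for a file on disk.
buffer = io.BytesIO()
torch.save(payload, buffer)

# Rewind the buffer, then let torch.load deserialize the object back into memory.
buffer.seek(0)
restored = torch.load(buffer)

# The restored tensor should hold the same values as the original.
print(torch.equal(payload["weights"], restored["weights"]))
```

Because unpickling can run arbitrary code, it is generally safest to load only files from sources you trust.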


Step 1 - Import library

import torch
import torch.nn as nn
import torch.nn.functional as F  # provides F.relu, which the model definition uses
import torch.optim as optim

Step 2 - Define Model

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    # nn.Module runs forward() when the model is called on an input
    def forward(self, X_data):
        X_data = self.pool(F.relu(self.conv1(X_data)))
        X_data = self.pool(F.relu(self.conv2(X_data)))
        # Flatten the feature maps before the fully connected layers
        X_data = X_data.view(-1, 16 * 5 * 5)
        X_data = F.relu(self.fc1(X_data))
        X_data = F.relu(self.fc2(X_data))
        X_data = self.fc3(X_data)
        return X_data

Model = MyModel()
print("This is our model architecture:", Model)

This is our model architecture: MyModel(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
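The 16 * 5 * 5 input size of fc1 follows from the convolution and pooling layers. The sketch below traces the shapes through the same layer settings, assuming a 3-channel 32x32 input image; the recipe does not state the input size, so that is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same layer settings as in MyModel
conv1 = nn.Conv2d(3, 6, 5)
pool = nn.MaxPool2d(2, 2)
conv2 = nn.Conv2d(6, 16, 5)

# Assumed input: a batch of one 3-channel 32x32 image
x = torch.randn(1, 3, 32, 32)

x = pool(F.relu(conv1(x)))  # 32 -> 28 after the 5x5 conv, -> 14 after pooling
x = pool(F.relu(conv2(x)))  # 14 -> 10 after the 5x5 conv, -> 5 after pooling
print(x.shape)  # torch.Size([1, 16, 5, 5])

# Flattening gives 16 * 5 * 5 = 400 features, matching in_features of fc1
flat = x.view(-1, 16 * 5 * 5)
print(flat.shape)  # torch.Size([1, 400])
```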

Step 3 - Initializing optimizer

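# Note: this assignment rebinds the name optim from the torch.optim module to the SGD optimizer instance
optim = optim.SGD(Model.parameters(), lr=0.01, momentum=0.9)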

Step 4 - Accessing Model

print("Accessing the model state_dict")
for values in Model.state_dict():
  print(values, "\t", Model.state_dict()[values].size())

Accessing the model state_dict
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
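The state_dict maps each parameter name to its tensor, so the sizes listed above are enough to work out how many learnable values the model holds. A short worked example, with the shapes copied from that listing into a plain dictionary (the `shapes` name is only illustrative):

```python
from math import prod

# Shapes copied from the state_dict listing
shapes = {
    "conv1.weight": (6, 3, 5, 5),
    "conv1.bias": (6,),
    "conv2.weight": (16, 6, 5, 5),
    "conv2.bias": (16,),
    "fc1.weight": (120, 400),
    "fc1.bias": (120,),
    "fc2.weight": (84, 120),
    "fc2.bias": (84,),
    "fc3.weight": (10, 84),
    "fc3.bias": (10,),
}

# Each tensor contributes the product of its dimensions
total = sum(prod(shape) for shape in shapes.values())
print(total)  # 62006
```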

Step 5 - Accessing Optimizer

print("Accessing the optimizers state_dict")
for elements in optim.state_dict():
  print(elements, "\t", optim.state_dict()[elements])

Accessing the optimizers state_dict
state 	 {}
param_groups 	 [{'lr': 0.01, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]

Step 6 - Save the model

# Call state_dict() so the parameter dictionary is saved, not the bound method
torch.save(Model.state_dict(), "Model.pickle")
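If training should resume later, the optimizer's state_dict can be stored alongside the model's. The following is a minimal sketch of one way to do that, not the recipe's own method: the stand-in nn.Linear model, the dictionary keys and the temporary file path are all assumptions chosen for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Small stand-in model so the sketch runs on its own
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Bundle both state_dicts in one dictionary; the key names are arbitrary
checkpoint = {
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}

# Hypothetical location: a file inside a fresh temporary directory
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(checkpoint, path)

# Restore both pieces into newly created objects
restored = torch.load(path)
new_model = nn.Linear(4, 2)
new_optimizer = optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)
new_model.load_state_dict(restored["model_state"])
new_optimizer.load_state_dict(restored["optimizer_state"])

print(sorted(restored.keys()))
```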

Step 7 - Load the model

# Load from the same path that was passed to torch.save
load = torch.load("Model.pickle")
load
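torch.load only returns the object that was saved. When that object is a state_dict, the weights still need to be copied into a model instance with load_state_dict before the model can be used. Here is a self-contained sketch of that step; `TinyNet` is a hypothetical stand-in for MyModel, and the temporary path and `map_location="cpu"` are assumptions for a CPU-only run.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for MyModel, kept small for brevity."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Save the parameters of one instance
trained = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "tiny.pickle")
torch.save(trained.state_dict(), path)

# Load the parameter dictionary, mapping tensors onto the CPU
state = torch.load(path, map_location="cpu")

# Create a fresh instance of the same class and copy the weights in
restored = TinyNet()
restored.load_state_dict(state)
restored.eval()  # switch to inference mode before predicting

# Both instances should now produce the same outputs
x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(trained(x), restored(x))
print(same)
```

The class definition has to be available at load time, because the state_dict stores only tensors and not the model's code.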

 


