How to save and reload a deep learning model in Pytorch?

How to save and reload a deep learning model in Pytorch?

How to save and reload a deep learning model in Pytorch?

This Pytorch recipe provides you a solution for saving and loading Pytorch models - entire models or just the parameters.


This recipe provides options to save and reload an entire model or just the parameters of the model. While reloading this recipe copies the parameter from 1 net to another net. There are 3 main functions involved in saving and loading a model in pytorch.

1. This saves a serialized object to disk. It uses python's pickle utility for serialization. Models, tensors and dictionaries can be saved using this function.
2. torch.load: torch.load: Uses pickle's unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into.
3. torch.nn.Module.load_state_dict: Loads a model's parameter dictionary using a deserialized state_dict. The learnable parameters (i.e. weights and biases) of an torch.nn.Module model are contained in the model's parameters (accessed with model.parameters()). A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor.

What is PyTorch ?
Pytorch is a Python-based scientific computing package that uses the power of graphics processing units and can replace the numpy library. It is also a very popular deep learning research platform built for flexibility and speed. You can use other Python packages such as NumPy, SciPy to extend PyTorch functionalities.

What is Deep Learning Model ?
Deep learning is a subset of machine learning. Deep learning uses neural networks to make predictions. A neural network takes inputs, which are then processed using hidden layers using weights that are adjusted during training. The model then outputs a prediction.

In [72]:
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
%matplotlib inline

torch.manual_seed(1)    # reproducible
<torch._C.Generator at 0x12214f310>
In [73]:
#sample data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
In [74]:
def save():
    # save net1
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    for t in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)

    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.plot(,, 'r-', lw=5)

    # 2 ways to save the net, 'net.pkl')  # save entire net, 'net_params.pkl')   # save only the parameters
In [75]:
def restore_net():
    # restore entire net1 to net2
    net2 = torch.load('net.pkl')
    prediction = net2(x)

    # plot result
    plt.plot(,, 'r-', lw=5)
In [76]:
def restore_params():
    # restore only the parameters in net1 to net3
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.Linear(10, 1)

    # copy net1's parameters into net3
    prediction = net3(x)

    # plot result
    plt.plot(,, 'r-', lw=5)
In [77]:
# save net1
# restore entire net (may slow)
# restore only the net parameters

Relevant Projects

Credit Card Fraud Detection as a Classification Problem
In this data science project, we will predict the credit card fraud in the transactional dataset using some of the predictive models.

Deep Learning with Keras in R to Predict Customer Churn
In this deep learning project, we will predict customer churn using Artificial Neural Networks and learn how to model an ANN in R with the keras deep learning package.

Zillow’s Home Value Prediction (Zestimate)
Data Science Project in R -Build a machine learning algorithm to predict the future sale prices of homes.

Learn to prepare data for your next machine learning project
Text data requires special preparation before you can start using it for any machine learning project.In this ML project, you will learn about applying Machine Learning models to create classifiers and learn how to make sense of textual data.

Choosing the right Time Series Forecasting Methods
There are different time series forecasting methods to forecast stock price, demand etc. In this machine learning project, you will learn to determine which forecasting method to be used when and how to apply with time series forecasting example.

Forecast Inventory demand using historical sales data in R
In this machine learning project, you will develop a machine learning model to accurately forecast inventory demand based on historical sales data.

Resume parsing with Machine learning - NLP with Python OCR and Spacy
In this machine learning resume parser example we use the popular Spacy NLP python library for OCR and text classification.

Walmart Sales Forecasting Data Science Project
Data Science Project in R-Predict the sales for each department using historical markdown data from the Walmart dataset containing data of 45 Walmart stores.

Human Activity Recognition Using Smartphones Data Set
In this deep learning project, you will build a classification system where to precisely identify human fitness activities.

Time Series Forecasting with LSTM Neural Network Python
Deep Learning Project- Learn to apply deep learning paradigm to forecast univariate time series data.