What is a state dict in PyTorch?

This recipe explains what a state dict is in PyTorch.

Recipe Objective

What is a state_dict in PyTorch?

A state_dict is a plain Python dictionary object used for saving and loading models in PyTorch. The learnable parameters (weights and biases) of a "torch.nn.Module" model are held in the model's parameters, which are accessed with model.parameters(); the state_dict maps each layer to its parameter tensors. Optimizers (torch.optim) also have a state_dict, which holds the optimizer's state and the hyperparameters in use. Because these are ordinary dictionaries, they can be easily saved, updated, altered, and restored, which adds a great deal of modularity to PyTorch models and optimizers. Let's understand this with a practical implementation.
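As a quick illustration of the save-and-restore idea described above, here is a minimal sketch. It assumes a tiny stand-in model (a single nn.Linear layer, not the network built in this recipe) and writes to an in-memory buffer instead of a file, so the names `model`, `restored` and `buffer` are purely illustrative.

```python
import io

import torch
import torch.nn as nn

# A small stand-in model (hypothetical, for illustration only).
model = nn.Linear(4, 2)

# Serialize only the state_dict, not the whole model object.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Later: build a fresh model with the same architecture and restore the weights.
buffer.seek(0)
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer))

# Check that every tensor in the restored state_dict equals the original.
weights_match = all(
    torch.equal(model.state_dict()[key], restored.state_dict()[key])
    for key in model.state_dict()
)
print("Weights match after reload:", weights_match)
```

In practice a file path such as "model.pt" would usually replace the in-memory buffer; the calls stay the same.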

Step 1 - Import library

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

Step 2 - Define and Initialize Neural network

class Neuralnet(nn.Module):
    def __init__(self):
        super().__init__()
        # Two convolutional layers sharing one max-pooling layer
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Three fully connected layers ending in 10 outputs
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    # The method must be named forward so that network(X_data) works
    def forward(self, X_data):
        X_data = self.pool(F.relu(self.conv1(X_data)))
        X_data = self.pool(F.relu(self.conv2(X_data)))
        # Flatten the feature maps before the fully connected layers
        X_data = X_data.view(-1, 16 * 5 * 5)
        X_data = F.relu(self.fc1(X_data))
        X_data = F.relu(self.fc2(X_data))
        X_data = self.fc3(X_data)
        return X_data

network = Neuralnet()
# Printing a module shows its layers (the architecture), not the parameter values
print("This is our neural network architecture:", network)

This is our neural network architecture: Neuralnet(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

Step 3 - Initializing optimizer

state_optim = optim.SGD(network.parameters(), lr=0.01, momentum=0.9)

Step 4 - Accessing the model state_dict

print("Accessing the model state_dict")
for name, tensor in network.state_dict().items():
    print(name, "\t", tensor.size())

Accessing the model state_dict
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
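The network in this recipe has only convolutional and linear layers, so its state_dict contains just weights and biases. A state_dict can also hold persistent buffers that are not learnable parameters, such as the running statistics of a batch-norm layer. The sketch below uses a hypothetical toy model (`toy`, not part of the recipe) to show the difference between named_parameters() and state_dict().

```python
import torch.nn as nn

# Hypothetical toy model with a batch-norm layer (not the recipe's network).
toy = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

param_names = [name for name, _ in toy.named_parameters()]
state_names = list(toy.state_dict().keys())

# Entries present in the state_dict but not among the learnable parameters.
buffer_only = [name for name in state_names if name not in param_names]

print("Learnable parameters:", param_names)
print("Extra state_dict entries:", buffer_only)
```

This is one reason to save state_dict() rather than only the output of parameters(): the buffers are needed to reproduce the model's behaviour in evaluation mode.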

Step 5 - Accessing the optimizer state_dict

print("Accessing the optimizer's state_dict")
for key, value in state_optim.state_dict().items():
    print(key, "\t", value)

Accessing the optimizer's state_dict
state 	 {}
param_groups 	 [{'lr': 0.01, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]
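The "state" entry above is empty because the optimizer has not taken a step yet. As a minimal sketch, assuming a small stand-in model and random data (the names `model`, `optimizer` and the dummy loss are illustrative, not part of the recipe), the snippet below shows the state being filled in after one update when momentum is enabled.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Stand-in model and optimizer (hypothetical, for illustration only).
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Before any update the optimizer holds no per-parameter state.
state_before = len(optimizer.state_dict()["state"])

# One dummy training step on random data.
loss = model(torch.randn(8, 3)).pow(2).mean()
loss.backward()
optimizer.step()

# With momentum enabled, SGD now stores a buffer for each parameter.
state_after = optimizer.state_dict()["state"]
print("Entries before step:", state_before)
print("Entries after step:", len(state_after))
print("Keys for parameter 0:", list(state_after[0].keys()))
```

Saving the optimizer's state_dict alongside the model's is what lets training resume with the same momentum buffers and hyperparameters.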
