What does optimizer.zero_grad() do in PyTorch?

This recipe explains what optimizer.zero_grad() does in PyTorch.

Recipe Objective

What does optimizer.zero_grad() do in PyTorch?

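As discussed earlier, the torch.optim package provides the optimizers that update the learnable weights of a model. Every optimizer has a zero_grad() method, which resets the gradients of all the tensors the optimizer manages. It does not update the weights itself; that is the job of optim.step(). The reset is needed because PyTorch accumulates gradients: each call to backward() adds to the values already stored in .grad instead of overwriting them. In recent PyTorch versions (2.0 and later) zero_grad() sets the gradients to None by default rather than filling them with zeros; passing set_to_none=False keeps zero-filled tensors. Let's understand this with a practical implementation.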
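To make the accumulation behaviour concrete, here is a minimal sketch on a single tensor. The names w and opt and the toy expression 3 * w are assumptions chosen for illustration and are not part of the recipe; set_to_none=False is passed explicitly so the gradient stays a zero-filled tensor on any PyTorch version that accepts the argument.

```python
import torch

# A single learnable tensor; "w" and "opt" are illustrative names only.
w = torch.tensor([1.0, 2.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

# First backward pass: d(sum(3 * w)) / dw is 3 for every element.
(3 * w).sum().backward()
grad_after_first = w.grad.clone()

# Second backward pass without a reset: the new gradient is added to the old one.
(3 * w).sum().backward()
grad_after_second = w.grad.clone()

# Reset the gradients; set_to_none=False keeps zero-filled tensors instead of None.
opt.zero_grad(set_to_none=False)
grad_after_zero = w.grad.clone()

print(grad_after_first, grad_after_second, grad_after_zero)
```

The second gradient is twice the first one, which is why a training loop normally resets the gradients once per iteration.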

Step 1 - Import library

import torch

Step 2 - Define parameters

batch, dim_in, dim_h, dim_out = 64, 1000, 100, 10

Here we are defining the following parameters:
batch - batch size
dim_in - input dimension
dim_h - hidden dimension
dim_out - output dimension

Step 3 - Create Random tensors

input_X = torch.randn(batch, dim_in)
output_Y = torch.randn(batch, dim_out)

Here we are creating random tensors for holding the input and output data.

Step 4 - Define model and loss function

Adam_model = torch.nn.Sequential(
    torch.nn.Linear(dim_in, dim_h),
    torch.nn.ReLU(),
    torch.nn.Linear(dim_h, dim_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

Step 5 - Define learning rate

rate_learning = 0.001

Step 6 - Initialize optimizer

optim = torch.optim.Adam(Adam_model.parameters(), lr=rate_learning)

Here we are initializing our optimizer from the torch.optim package, which will update the weights of the model for us. We are using the Adam optimizer here; torch.optim also provides many other optimization algorithms, such as SGD.

Step 7 - Forward pass

for values in range(1000):
    pred_y = Adam_model(input_X)
    loss = loss_fn(pred_y, output_Y)
    if values % 100 == 99:
        print(values, loss.item())

99 647.01220703125
199 647.01220703125
299 647.01220703125
399 647.01220703125
499 647.01220703125
599 647.01220703125
699 647.01220703125
799 647.01220703125
899 647.01220703125
999 647.01220703125

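Here we are computing the predicted y by passing input_X to the model, then computing the loss and printing it every 100 iterations. The printed loss never changes because this loop only runs the forward pass: no gradients are computed and the weights are never updated.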
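A forward pass on its own neither changes the weights nor fills in any gradients, which can be checked with a small sketch. The tiny Linear model and the names x, y, first, second and grads_populated are assumptions made for this illustration, not part of the recipe.

```python
import torch

torch.manual_seed(0)  # fixed seed so the random tensors are reproducible

# A deliberately tiny stand-in model and data set (illustrative only).
model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)
y = torch.randn(8, 2)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Two forward passes with no backward() or step() in between.
first = loss_fn(model(x), y).item()
second = loss_fn(model(x), y).item()

# .grad stays None for every parameter until backward() is called.
grads_populated = [p.grad is not None for p in model.parameters()]

print(first == second, grads_populated)
```

Both loss values are identical and no parameter has a gradient yet, which matches the constant loss in the output above.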

Step 8 - Zero all gradients

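optim.zero_grad()

Here, before the backward pass, we must zero the gradients of all the variables the optimizer will update, which are the learnable weights of the model. Otherwise the gradients from the previous backward pass would be added to the new ones. Note that zero_grad() returns None, so there is no need to assign its result to a variable.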
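For context, here is one way the steps of this recipe could be combined into a loop that actually trains the model. This is a sketch, not the recipe's own code: the backward() and step() calls, the fixed seed, the shorter run of 200 iterations and the losses list are assumptions added for illustration.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed so repeated runs behave the same

batch, dim_in, dim_h, dim_out = 64, 1000, 100, 10
input_X = torch.randn(batch, dim_in)
output_Y = torch.randn(batch, dim_out)

Adam_model = torch.nn.Sequential(
    torch.nn.Linear(dim_in, dim_h),
    torch.nn.ReLU(),
    torch.nn.Linear(dim_h, dim_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')
optim = torch.optim.Adam(Adam_model.parameters(), lr=0.001)

losses = []  # hypothetical helper list, used only to inspect progress
for values in range(200):
    # Forward pass: compute predictions and the loss.
    pred_y = Adam_model(input_X)
    loss = loss_fn(pred_y, output_Y)
    losses.append(loss.item())

    # Clear the gradients left over from the previous iteration.
    optim.zero_grad()
    # Backward pass: compute fresh gradients of the loss w.r.t. the weights.
    loss.backward()
    # Update the weights using the gradients just computed.
    optim.step()

print(losses[0], losses[-1])
```

With the reset, backward pass and update in place, the loss should fall from one iteration to the next instead of staying constant as in the forward-only loop.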
