How to optimize a function using Adam in PyTorch

This recipe helps you optimize a function using Adam in PyTorch.

Recipe Objective

How to optimize a function using Adam in PyTorch?

The Adam optimizer is an optimization technique used in machine learning and deep learning, and it belongs to the family of gradient descent algorithms. It works well on large problems that involve a lot of data, it is computationally efficient, and it has modest memory requirements. Adam combines two other gradient descent methods, momentum and RMSProp: it keeps a running average of past gradients (as momentum does) and a running average of past squared gradients (as RMSProp does), and uses both to adapt the step size for each parameter. The optimizer is relatively easy to configure, and its default configuration parameters work well on most problems. To optimize a function we are going to use torch.optim, a package that implements numerous optimization algorithms. The most commonly used methods are already supported, and the interface is general enough that more sophisticated ones can easily be integrated in the future.
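The combination of momentum and RMSProp is easier to see when a single Adam update is written out by hand. The sketch below is an illustration, not the torch.optim source code: `adam_step` is a hypothetical helper, the toy objective f(w) = sum(w^2) and lr=0.1 are assumptions chosen for the example, and the weight decay and amsgrad options of torch.optim.Adam are left out. It uses PyTorch's default betas=(0.9, 0.999) and eps=1e-8 and checks that the hand-written version agrees with torch.optim.Adam on the same problem.

```python
import torch

def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical helper: one plain Adam update (no weight decay, no amsgrad)."""
    m = beta1 * m + (1 - beta1) * grad           # momentum part: running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad    # RMSProp part: running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                 # bias correction, since m and v start at zero
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (v_hat.sqrt() + eps)  # per-parameter adaptive step
    return param, m, v

# Assumed toy problem: minimise f(w) = sum(w^2) from the same starting point twice.
w0 = torch.tensor([1.0, -2.0, 3.0])
steps = 50

# Hand-written Adam
w_manual = w0.clone()
m = torch.zeros_like(w_manual)
v = torch.zeros_like(w_manual)
for t in range(1, steps + 1):
    grad = 2 * w_manual                          # analytic gradient of sum(w^2)
    w_manual, m, v = adam_step(w_manual, grad, m, v, t, lr=0.1)

# torch.optim.Adam with the same learning rate and default betas/eps
w_torch = w0.clone().requires_grad_(True)
opt = torch.optim.Adam([w_torch], lr=0.1)
for _ in range(steps):
    opt.zero_grad()
    loss = (w_torch ** 2).sum()
    loss.backward()                              # autograd computes the same 2 * w gradient
    opt.step()

print(w_manual)
print(w_torch.detach())
print(torch.allclose(w_manual, w_torch.detach(), atol=1e-4))
```

The two results should agree up to floating-point rounding, which is why the comparison uses a small tolerance rather than exact equality.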


Step 1 - Import library

import torch

Step 2 - Define parameters

batch, dim_in, dim_h, dim_out = 128, 2000, 200, 20

Here we define the following parameters:
batch - batch size.
dim_in - input dimension.
dim_h - hidden dimension.
dim_out - output dimension.

Step 3 - Create Random tensors

input_X = torch.randn(batch, dim_in)
output_Y = torch.randn(batch, dim_out)

Here we create random tensors to hold the input and output data: input_X has shape (batch, dim_in) and output_Y has shape (batch, dim_out).

Step 4 - Define model and loss function

Adam_model = torch.nn.Sequential(
    torch.nn.Linear(dim_in, dim_h),   # input layer -> hidden layer
    torch.nn.ReLU(),                  # non-linearity
    torch.nn.Linear(dim_h, dim_out),  # hidden layer -> output layer
)
loss_fn = torch.nn.MSELoss(reduction='sum')
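The loss is created with reduction='sum', so it adds up the squared errors over every element instead of averaging them (the default for MSELoss is reduction='mean'). A minimal sketch of the difference, using small made-up tensors `pred` and `target` rather than the recipe's data:

```python
import torch

pred = torch.randn(4, 3)     # hypothetical predictions
target = torch.randn(4, 3)   # hypothetical targets

loss_sum = torch.nn.MSELoss(reduction='sum')(pred, target)
loss_mean = torch.nn.MSELoss()(pred, target)   # default reduction='mean'
manual = ((pred - target) ** 2).sum()          # the same sum computed by hand

print(loss_sum.item(), manual.item())
# The summed loss equals the mean loss times the number of elements.
print(torch.isclose(loss_sum, loss_mean * pred.numel()))
```

With the sum reduction, the size of the loss (and of its gradients) grows with the batch size and the output dimension, which is worth keeping in mind when choosing the learning rate.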

Step 5 - Define learning rate

rate_learning = 1e-4

Step 6 - Initialize optimizer

optim = torch.optim.Adam(Adam_model.parameters(), lr=rate_learning)

Here we initialize our optimizer from the torch.optim package; it will update the weights of the model for us. We are using the Adam optimizer here, one of the many optimization algorithms that the torch.optim package provides.
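Only the learning rate is passed explicitly, so the other Adam hyperparameters keep their default values. One way to inspect them is through the optimizer's param_groups; the small `torch.nn.Linear(3, 1)` model below is a stand-in assumed only for this illustration.

```python
import torch

model = torch.nn.Linear(3, 1)                  # stand-in model for the example
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

group = optim.param_groups[0]                  # hyperparameters for this group of parameters
print(group['lr'], group['betas'], group['eps'], group['weight_decay'])

# The optimizer holds references to the model's own parameter tensors,
# which is how optim.step() can update the model's weights in place.
shares_params = all(p is q for p, q in zip(group['params'], model.parameters()))
print(shares_params)
```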

Step 7 - Forward pass

for values in range(500):
   pred_y = Adam_model(input_X)
   loss = loss_fn(pred_y, output_Y)
   if values % 100 == 99:
      print(values, loss.item())

99 698.3545532226562
199 698.3545532226562
299 698.3545532226562
399 698.3545532226562
499 698.3545532226562

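Here we compute the predicted y by passing input_X to the model, then compute the loss and print it every 100 iterations. Note that this loop only runs the forward pass: it never computes gradients or updates the weights, which is why the printed loss is identical at every checkpoint. To actually train the model, Steps 8 to 10 have to run inside the loop, right after the loss is computed.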

Step 8 - Zero all gradients

optim.zero_grad()

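Before the backward pass, we zero the gradients of all the variables the optimizer will update, which are the learnable weights of the model. This is needed because each call to backward() adds the new gradients to whatever is already stored in each parameter's .grad attribute instead of overwriting it.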
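The accumulation behaviour can be seen on a single scalar weight. In this sketch the weight `w`, the toy loss w^2 and the helper `compute_loss` are assumptions made for the illustration; the gradient of w^2 at w = 2 is 4.

```python
import torch

w = torch.tensor([2.0], requires_grad=True)
opt = torch.optim.Adam([w], lr=1e-4)

def compute_loss():
    # Hypothetical toy loss: w^2, whose gradient is 2 * w = 4 at w = 2
    return (w ** 2).sum()

compute_loss().backward()
first = w.grad.clone()      # gradient after one backward pass

compute_loss().backward()   # second backward pass WITHOUT zero_grad()
second = w.grad.clone()     # new gradient was added to the old one

opt.zero_grad()             # clear the stored gradient before the next pass
compute_loss().backward()
third = w.grad.clone()      # back to a single gradient

print(first, second, third)
```

Depending on the PyTorch version, zero_grad() either fills the gradients with zeros or sets them to None; in both cases the next backward() starts from a clean gradient.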

Step 9 - Backward pass

loss.backward()

Here we compute the gradients of the loss with respect to the model parameters.

Step 10 - Call step function

optim.step()

Here we call the optimizer's step() method, which updates the model parameters using the gradients computed in the backward pass.
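Putting Steps 7 to 10 together, a complete training loop runs the forward pass, zeroes the gradients, runs the backward pass and calls step() on every iteration. The sketch below reuses the recipe's names and sizes; the fixed seed and the `losses` list are additions assumed for the example so that runs are repeatable and the loss history can be checked.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed so runs are repeatable

batch, dim_in, dim_h, dim_out = 128, 2000, 200, 20
input_X = torch.randn(batch, dim_in)
output_Y = torch.randn(batch, dim_out)

Adam_model = torch.nn.Sequential(
    torch.nn.Linear(dim_in, dim_h),
    torch.nn.ReLU(),
    torch.nn.Linear(dim_h, dim_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')
rate_learning = 1e-4
optim = torch.optim.Adam(Adam_model.parameters(), lr=rate_learning)

losses = []  # hypothetical: loss history kept for inspection
for values in range(500):
    pred_y = Adam_model(input_X)       # Step 7: forward pass
    loss = loss_fn(pred_y, output_Y)
    losses.append(loss.item())
    if values % 100 == 99:
        print(values, loss.item())

    optim.zero_grad()                  # Step 8: clear gradients from the previous iteration
    loss.backward()                    # Step 9: compute new gradients
    optim.step()                       # Step 10: let Adam update the weights
```

Because the weights now change on every iteration, the printed loss should decrease over the run instead of staying fixed.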
