How to use pretrained torch models for classification

This recipe helps you use pretrained torch models for classification.

Recipe Objective

How to use pre-trained torch models for classification?

This is done with the torchvision.models package, which contains model definitions for a range of tasks: image classification, semantic segmentation, instance segmentation, object detection, person keypoint detection, and video classification. The following image classification architectures are available:
AlexNet
VGG
ResNet
SqueezeNet
DenseNet
Inception v3
GoogLeNet
ShuffleNet v2
MobileNet v2
ResNeXt
Wide ResNet
MNASNet

Step 1 - Import library

from torchvision import models
from torchvision import transforms
from PIL import Image
import torch

Step 2 - directory of models

dir(models)

['AlexNet',
 'DenseNet',
 'GoogLeNet',
 'GoogLeNetOutputs',
 'Inception3',
 'InceptionOutputs',
 'MNASNet',
 'MobileNetV2',
 'ResNet',
 'ShuffleNetV2',
 'SqueezeNet',
 'VGG',
 '_GoogLeNetOutputs',
 '_InceptionOutputs',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__path__',
 '__spec__',
 '_utils',
 'alexnet',
 'densenet',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'detection',
 'googlenet',
 'inception',
 'inception_v3',
 'mnasnet',
 'mnasnet0_5',
 'mnasnet0_75',
 'mnasnet1_0',
 'mnasnet1_3',
 'mobilenet',
 'mobilenet_v2',
 'quantization',
 'resnet',
 'resnet101',
 'resnet152',
 'resnet18',
 'resnet34',
 'resnet50',
 'resnext101_32x8d',
 'resnext50_32x4d',
 'segmentation',
 'shufflenet_v2_x0_5',
 'shufflenet_v2_x1_0',
 'shufflenet_v2_x1_5',
 'shufflenet_v2_x2_0',
 'shufflenetv2',
 'squeezenet',
 'squeezenet1_0',
 'squeezenet1_1',
 'utils',
 'vgg',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn',
 'video',
 'wide_resnet101_2',
 'wide_resnet50_2']

Step 3 - Load the model

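# Load AlexNet with ImageNet-pretrained weights (downloaded on first use).
# Note: torchvision 0.13 and later deprecate pretrained=True in favor of the weights= argument.
classification_alexnet = models.alexnet(pretrained=True)
print(classification_alexnet)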

Downloading: "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-4df8aa71.pth
100% 233M/233M [00:02<00:00, 88.8MB/s]

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

Step 4 - Specify Image transformation

image_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

The code above chains the following image transformations:
-- transforms.Compose - Combines all the listed transformations into a single callable, stored here in image_transforms, that is applied to the input image in order.
-- transforms.Resize(256) - Resizes the image so that its shorter side is 256 pixels, preserving the aspect ratio.
-- transforms.CenterCrop(224) - Crops the central 224x224 pixel region of the image.
-- transforms.ToTensor() - Converts the image into a PyTorch tensor with values scaled to the range [0, 1].
-- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - Normalizes each color channel by subtracting the given mean and dividing by the given standard deviation. These values are the ImageNet statistics that the pretrained models expect.

Step 5 - Load the image

flower_image = Image.open("/content/yellow-orange-starburst-flower-nature-jpg-192959431.jpg")
flower_image

Step 6 - Preprocess the image

transform_flower_image = image_transforms(flower_image)
transform_batch = torch.unsqueeze(transform_flower_image, 0)
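
The model expects a batch of images with shape (batch, channels, height, width), which is why a leading dimension is added with torch.unsqueeze. The sketch below uses random tensors as stand-ins for preprocessed images (an assumption made so it runs without an image file) and also shows torch.stack as one way to batch several images of the same size.

```python
import torch

# Stand-in for one preprocessed image: (channels, height, width).
single_image = torch.rand(3, 224, 224)

# unsqueeze(0) adds a leading batch dimension of size 1.
batch_of_one = torch.unsqueeze(single_image, 0)
print(batch_of_one.shape)

# Several preprocessed images of the same size can be stacked instead.
batch_of_four = torch.stack([torch.rand(3, 224, 224) for _ in range(4)])
print(batch_of_four.shape)
```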

Step 7 - Model Inference

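classification_alexnet.eval()  # switch to inference mode (disables dropout)
with torch.no_grad():  # gradients are not needed for inference
    out = classification_alexnet(transform_batch)
print(out.shape)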

torch.Size([1, 1000])
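
The output holds one raw score (logit) for each of the 1000 ImageNet classes. A common next step, sketched here under the assumption that you want the five most likely classes, is to apply a softmax and take the top entries. Random logits stand in for the real model output so the example runs on its own; mapping the class indices to readable names requires an ImageNet label list, which is not part of this recipe.

```python
import torch

torch.manual_seed(0)
# Stand-in for the model output `out`: raw scores (logits) for 1000 classes.
logits = torch.randn(1, 1000)

# Softmax turns the logits into probabilities that sum to 1 per image.
probabilities = torch.nn.functional.softmax(logits, dim=1)

# The five most probable classes and their probabilities.
top_probs, top_indices = torch.topk(probabilities, k=5, dim=1)
for prob, idx in zip(top_probs[0], top_indices[0]):
    print(f"class index {idx.item()}: probability {prob.item():.4f}")
```

With the real model, replacing logits with out gives the predicted class indices for the flower image.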
