What is SoftmaxCrossEntropyLoss in MXNet?

This recipe explains what SoftmaxCrossEntropyLoss is in MXNet and how to use it when training a small network on MNIST.

Recipe Objective: What is SoftmaxCrossEntropyLoss in MXNet?

Step 1: Importing libraries

Let us first import the necessary libraries.

import math
import mxnet as mx
import numpy as np
from mxnet import nd, autograd, gluon
from mxnet.gluon.data.vision import transforms

Step 2: Data Set

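We'll use the MNIST data set of handwritten digits. gluon.data.DataLoader() serves it in shuffled batches of 128 images, and transforms.ToTensor() converts each image to a float32 tensor in channel-first layout with pixel values scaled to the range [0, 1).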

train = gluon.data.DataLoader(gluon.data.vision.MNIST(train=True).transform_first(transforms.ToTensor()), 128, shuffle=True)

Step 3: Neural Network

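The network has two convolutional layers, each followed by max pooling, and two dense layers. The network() function below adds these layers to a container passed in by the caller, for example gluon.nn.Sequential(). The last layer outputs 10 raw class scores, one per digit.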

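def network(net):
    # net is an empty container such as gluon.nn.Sequential()
    with net.name_scope():
        # Two convolution + max-pooling stages
        net.add(gluon.nn.Conv2D(channels=10, kernel_size=1, activation='relu'))
        net.add(gluon.nn.MaxPool2D(pool_size=4, strides=4))
        net.add(gluon.nn.Conv2D(channels=20, kernel_size=1, activation='relu'))
        net.add(gluon.nn.MaxPool2D(pool_size=4, strides=4))
        # Flatten the feature maps and classify with two dense layers
        net.add(gluon.nn.Flatten())
        net.add(gluon.nn.Dense(256, activation="relu"))
        # 10 raw class scores (logits); no softmax here, the loss applies it
        net.add(gluon.nn.Dense(10))
    return net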
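
To see what the layers above do to a 28x28 MNIST image, the spatial size can be traced by hand. The sketch below is plain Python and does not call MXNet; conv_out is a hypothetical helper, and it assumes the Gluon defaults of stride 1 and no padding for Conv2D, and no padding with floor rounding for MaxPool2D.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution or pooling window along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                  # MNIST images are 28x28
size = conv_out(size, kernel=1)            # first Conv2D, kernel_size=1
size = conv_out(size, kernel=4, stride=4)  # first MaxPool2D, pool_size=4, strides=4
size = conv_out(size, kernel=1)            # second Conv2D, kernel_size=1
size = conv_out(size, kernel=4, stride=4)  # second MaxPool2D
flattened = 20 * size * size               # 20 channels come out of the second Conv2D

print(size, flattened)
```

Under these assumptions the feature map shrinks to a single pixel per channel, so Flatten() hands 20 values to the first Dense layer.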

Step 4: Loss with Softmax

When training a neural network, the learning rate chosen for SGD (Stochastic Gradient Descent) is essential to both the final performance of the network and the speed of convergence. The most straightforward strategy is to keep the learning rate constant throughout training. A small learning rate lets the optimizer find reasonable solutions, but at the expense of limiting the initial speed of convergence. Changing the learning rate over time can resolve this, so the modeltrain() function in this step uses mx.lr_scheduler.MultiFactorScheduler to multiply the learning rate by 0.1 at a chosen set of steps.

SoftmaxCrossEntropyLoss() computes the softmax cross-entropy loss. To avoid numerical instabilities, it applies softmax and cross-entropy as a single fused operation, which is why the final layer of the network outputs raw scores (logits) with no softmax activation. By default it expects each label to be an integer class index rather than a one-hot vector.
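
Why fusing the two steps helps can be shown without MXNet. The following is a minimal pure-Python sketch, not the library's implementation: both function names are hypothetical, and the fused version uses the standard log-sum-exp trick, which is assumed here to be representative of what a fused operator does.

```python
import math

def naive_softmax_cross_entropy(logits, label):
    """Softmax first, then -log(p[label]), as two separate steps."""
    exps = [math.exp(z) for z in logits]   # math.exp overflows for large logits
    return -math.log(exps[label] / sum(exps))

def fused_softmax_cross_entropy(logits, label):
    """Same quantity computed as logsumexp(logits) - logits[label]."""
    m = max(logits)                        # subtracting the max keeps exp() in range
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

small = [2.0, 1.0, 0.1]
large = [1000.0, 999.0, 998.1]             # the same scores shifted by 998

loss_small_naive = naive_softmax_cross_entropy(small, 0)
loss_small_fused = fused_softmax_cross_entropy(small, 0)
loss_large_fused = fused_softmax_cross_entropy(large, 0)

try:
    naive_softmax_cross_entropy(large, 0)
    naive_failed = False
except OverflowError:
    naive_failed = True

print(loss_small_naive, loss_small_fused, loss_large_fused, naive_failed)
```

Both versions agree on the small scores (a loss of roughly 0.417), but only the fused one survives the large scores, even though shifting all logits by a constant does not change the softmax probabilities.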

def modeltrain(model):
    model.initialize()
    # len(train) is already the number of batches per epoch (batch size 128)
    iterations = len(train)
    # Lower the learning rate after epochs 1, 2 and 3
    steps = [s * iterations for s in [1, 2, 3]]
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    learning_rate = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=0.1)
    cnt = mx.optimizer.SGD(learning_rate=0.03, lr_scheduler=learning_rate)
    trainer = mx.gluon.Trainer(params=model.collect_params(), optimizer=cnt)
    # One epoch only; increase the range to see the learning rate drop
    for epoch in range(1):
        for batch_num, (data, label) in enumerate(train):
            data = data.as_in_context(mx.cpu())
            label = label.as_in_context(mx.cpu())
            # Record the forward pass so gradients can be computed
            with autograd.record():
                output = model(data)
                loss = softmax_cross_entropy(output, label)
            loss.backward()
            # Normalize the gradient by the batch size and update the weights
            trainer.step(data.shape[0])
            if batch_num % 50 == 0:
                curr_loss = nd.mean(loss).asscalar()
                print("Epoch: %d; Batch %d; Loss %f" % (epoch, batch_num, curr_loss))
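
The effect of a multi-factor schedule on the learning rate can be worked out by hand. The sketch below is plain Python rather than MXNet's own code: multifactor_lr is a hypothetical helper, and it assumes the rate is multiplied by the factor once the update count exceeds each step. The numbers are the 60,000 MNIST training images, a batch size of 128, and the base rate of 0.03 and factor of 0.1 used above.

```python
import math

def multifactor_lr(num_update, base_lr, steps, factor):
    """Learning rate after num_update updates, assuming a drop once each step is exceeded."""
    passed = sum(1 for s in steps if num_update > s)
    return base_lr * factor ** passed

num_samples = 60000                                # size of the MNIST training set
batch_size = 128
iterations = math.ceil(num_samples / batch_size)   # batches per epoch
steps = [s * iterations for s in [1, 2, 3]]        # drop after epochs 1, 2 and 3

for n in (1, iterations, iterations + 1, 2 * iterations + 1, 3 * iterations + 1):
    print(n, multifactor_lr(n, 0.03, steps, 0.1))
```

With these numbers the rate stays at 0.03 for the whole first epoch (469 updates), so a drop only becomes visible when training runs for more than one epoch.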
