How to train an MNIST classifier using Trainer in Chainer

This recipe helps you train a neural network on MNIST using the Trainer in Chainer.

Recipe Objective - How to train an MNIST classifier using Trainer in Chainer?

With the help of the Trainer, we no longer have to write the training loop explicitly. Chainer also provides many extensions that plug into the Trainer.

With these extensions we can visualize the results, evaluate the model, and store and manage log files more easily.

We will use the MNIST dataset to train our neural network with the Trainer.

import math
import numpy as np
import chainer
from chainer import backend
from chainer import backends
from chainer.backends import cuda
from chainer import Function, FunctionNode, gradient_check, report, training, utils, Variable
from chainer import datasets, initializers, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions

1. Preparing the MNIST dataset:-


from chainer.datasets import mnist
train, test = mnist.get_mnist()

2. Preparing the dataset iterators:-


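batchsize = 120
# Training iterator: repeats over the dataset and shuffles it every epoch (the defaults)
train_iter = iterators.SerialIterator(train, batchsize)
# Test iterator: visits each example once, in order
test_iter = iterators.SerialIterator(test, batchsize, repeat=False, shuffle=False)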
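
To make the iterator settings concrete, the sketch below mimics in plain Python how a dataset is cut into mini-batches when repeat and shuffle are both turned off, as for test_iter. The helper name serial_batches is hypothetical, and this is a simplified illustration rather than Chainer's implementation.

```python
def serial_batches(dataset, batch_size):
    """Yield consecutive mini-batches once, in order (no repeat, no shuffle)."""
    for start in range(0, len(dataset), batch_size):
        yield dataset[start:start + batch_size]


# Toy stand-in for a dataset: ten integer "examples".
toy_dataset = list(range(10))
batches = list(serial_batches(toy_dataset, batch_size=4))
print(batches)  # the last batch is shorter because 10 is not a multiple of 4

# The MNIST test split has 10000 examples; with batchsize = 120 the final
# batch is also shorter than the others.
n_test, batchsize = 10000, 120
n_test_batches = -(-n_test // batchsize)  # ceiling division
last_batch_size = n_test % batchsize or batchsize
print(n_test_batches, last_batch_size)
```

A shorter final batch is harmless for evaluation, which is one reason the test iterator can simply run through the data once.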

3. Preparing the model:-

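class MNST(Chain):

    def __init__(self, n_mid_units=90, n_out=10):
        super(MNST, self).__init__()
        # Layers must be created inside init_scope() so that their parameters are registered
        with self.init_scope():
            # Passing None lets Chainer infer the input size from the first batch
            self.l1 = L.Linear(None, n_mid_units)
            self.l2 = L.Linear(None, n_mid_units)
            self.l3 = L.Linear(None, n_out)

    def forward(self, x):
        # Two hidden layers with ReLU, followed by the output layer (raw class scores)
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)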

gpu_id = -1 # Set to 0 if you use GPU

model = MNST()
if gpu_id >= 0:
  model.to_gpu(gpu_id)
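
As a quick sanity check on the size of this network, the following sketch counts the weights and biases of the three fully connected layers. It assumes that each MNIST image arrives flattened to 784 values, which is what get_mnist() returns by default; the helper name count_mlp_parameters is hypothetical.

```python
def count_mlp_parameters(layer_sizes):
    """Count weights and biases of fully connected layers with the given sizes."""
    total = 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += n_in * n_out  # weight matrix
        total += n_out         # bias vector
    return total


# 784 inputs -> 90 hidden -> 90 hidden -> 10 outputs, matching the defaults above
layer_sizes = [784, 90, 90, 10]
n_params = count_mlp_parameters(layer_sizes)
print(n_params)
```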

4. Prepare the Updater:-

max_epoch = 10

# Wrap your model by Classifier and include the process of loss calculation within your model.
# Since we do not specify a loss function here, the default 'softmax_cross_entropy' is used.
model = L.Classifier(model)

# selection of your optimizing method
optimizer = optimizers.MomentumSGD()

# Give the optimizer a reference to the model
optimizer.setup(model)

# Get an updater that uses the Iterator and Optimizer
updater = training.updaters.StandardUpdater(train_iter, optimizer, device=gpu_id)
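
L.Classifier reports a loss and an accuracy for every mini-batch. The sketch below shows in plain Python what softmax cross entropy and accuracy compute for a small batch of class scores. The function names and the numbers are made up for illustration; Chainer's own implementation is vectorized and differs in detail.

```python
import math


def softmax_cross_entropy(scores, labels):
    """Mean negative log-probability of the correct class over the batch."""
    total = 0.0
    for row, label in zip(scores, labels):
        m = max(row)  # subtract the maximum for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def accuracy(scores, labels):
    """Fraction of rows whose highest score is at the labelled class."""
    correct = sum(1 for row, label in zip(scores, labels)
                  if row.index(max(row)) == label)
    return correct / len(labels)


# Illustrative scores for a batch of 3 examples and 3 classes
scores = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.5, 0.0]]
labels = [0, 2, 0]
loss = softmax_cross_entropy(scores, labels)
acc = accuracy(scores, labels)
print(loss, acc)
```

The optimizer minimizes the loss, while the accuracy is only reported for monitoring.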

5. Set up the Trainer:-


# Setup a Trainer
trainer = training.Trainer(updater, (max_epoch, 'epoch'), out='mnist_result')
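
To illustrate what the Trainer automates, here is a toy loop in plain Python: it calls an update step until the stop condition of max_epoch epochs is reached and invokes each extension at the end of every epoch. The name run_toy_trainer and its arguments are hypothetical, and the real Trainer is considerably more flexible (for example, each extension has its own trigger).

```python
def run_toy_trainer(n_examples, batch_size, max_epoch, update, extensions=()):
    """Run `update` once per mini-batch and call extensions after each epoch."""
    iterations_per_epoch = -(-n_examples // batch_size)  # ceiling division
    iteration = 0
    for epoch in range(1, max_epoch + 1):
        for _ in range(iterations_per_epoch):
            update()  # stands in for one optimizer step on one mini-batch
            iteration += 1
        for extension in extensions:
            extension(epoch, iteration)
    return iteration


history = []
total_iterations = run_toy_trainer(
    n_examples=60000,  # size of the MNIST training split
    batch_size=120,
    max_epoch=10,
    update=lambda: None,  # placeholder: no real training happens here
    extensions=[lambda epoch, iteration: history.append((epoch, iteration))],
)
print(total_iterations)
print(history[0], history[-1])
```

With the recipe's settings this amounts to 500 updates per epoch and 5000 updates in total.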

6. Add Extensions to the Trainer object:-


"LogReport()" - to save the log files automatically.
"snapshot()" - to save the whole Trainer state at regular intervals, so that training can be resumed later.
"snapshot_object()" - to save a specific object (here the model) at regular intervals.
"Evaluator()" - to evaluate the model on the test dataset.
"PrintReport()" - to display the training information in the terminal.
"PlotReport()" - to plot the reported values as a graph and save it as an image (requires matplotlib).
"ProgressBar()" - to display a progress bar in the terminal.
"DumpGraph()" - to save the computational graph as a Graphviz dot file.

from chainer.training import extensions

trainer.extend(extensions.LogReport())
trainer.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
trainer.extend(extensions.snapshot_object(model.predictor, filename='model_epoch-{.updater.epoch}'))
trainer.extend(extensions.Evaluator(test_iter, model, device=gpu_id))
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
trainer.extend(extensions.DumpGraph('main/loss'))

7. Run the training:-

# Nothing is trained until run() is called
trainer.run()
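
LogReport writes its entries as a JSON list to a file named log inside the output directory (mnist_result here). A minimal sketch for finding the epoch with the best validation accuracy follows. The helper name best_epoch is hypothetical, and the lookup only does something once training has actually run and the file exists.

```python
import json
import os


def best_epoch(log_path, key='validation/main/accuracy'):
    """Return the log entry with the highest value for `key`, or None."""
    with open(log_path) as f:
        entries = json.load(f)  # a list of dicts, one per reported interval
    scored = [entry for entry in entries if key in entry]
    if not scored:
        return None
    return max(scored, key=lambda entry: entry[key])


log_path = os.path.join('mnist_result', 'log')
if os.path.exists(log_path):
    best = best_epoch(log_path)
    if best is not None:
        print(best['epoch'], best['validation/main/accuracy'])
```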
