What is Trainer in transformers?

This recipe explains what the Trainer class is in transformers.

Recipe Objective - What is Trainer in transformers?

The Trainer and TFTrainer classes provide an API for feature-complete training in most standard use cases.
Both Trainer and TFTrainer contain a basic training loop. To inject custom behavior, you can subclass them and override the following methods (a basic usage sketch follows the list):
1. get_train_dataloader/get_train_tfdataset – Creates the training DataLoader (PyTorch) or TF Dataset.
2. get_eval_dataloader/get_eval_tfdataset – Creates the evaluation DataLoader (PyTorch) or TF Dataset.
3. get_test_dataloader/get_test_tfdataset – Creates the test DataLoader (PyTorch) or TF Dataset.
4. log – Logs information on the various objects watching training.
5. create_optimizer_and_scheduler – Sets up the optimizer and learning rate scheduler if they were not passed at init. Note that you can also subclass or override the create_optimizer and create_scheduler methods separately.
6. create_optimizer – Sets up the optimizer if it wasn’t passed at init.
7. create_scheduler – Sets up the learning rate scheduler if it wasn’t passed at init.
8. compute_loss – Computes the loss on a batch of training inputs.
9. training_step – Performs a training step.
10. prediction_step – Performs an evaluation/test step.
11. run_model (TensorFlow only) – Basic pass through the model.
12. evaluate – Runs an evaluation loop and returns metrics.
13. predict – Returns predictions (with metrics if labels are available) on a test set.
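
For orientation, here is a minimal sketch of the default Trainer training loop before any customization. The checkpoint name (bert-base-uncased), the imdb dataset, and the hyperparameter values below are illustrative assumptions rather than part of the original recipe:

# Minimal sketch of the default Trainer loop (PyTorch).
# The checkpoint, dataset, and hyperparameters are illustrative assumptions.
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Load and tokenize a small slice of the imdb dataset to keep the sketch quick
dataset = load_dataset("imdb")
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length")
encoded = dataset.map(tokenize, batched=True)
train_data = encoded["train"].shuffle(seed=42).select(range(1000))
eval_data = encoded["test"].shuffle(seed=42).select(range(200))

training_args = TrainingArguments(
    output_dir="./results",            # where checkpoints are written
    num_train_epochs=1,
    per_device_train_batch_size=8,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
)

trainer.train()                # runs the training loop
metrics = trainer.evaluate()   # runs the evaluation loop and returns metrics

The methods listed above are the hooks to override when this default behavior is not enough.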

For more related projects -

/projects/data-science-projects/deep-learning-projects
/projects/data-science-projects/tensorflow-projects

Example -

Let's see how to customize Trainer using a custom loss function for multi-label classification:

# Importing libraries
from torch import nn
from transformers import Trainer

# Customize the Trainer with a custom loss function
class MultilabelTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss_fct = nn.BCEWithLogitsLoss()
        loss = loss_fct(logits.view(-1, self.model.config.num_labels),
                        labels.float().view(-1, self.model.config.num_labels))
        return (loss, outputs) if return_outputs else loss
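
The custom trainer is then used exactly like the base Trainer. Below is a hedged usage sketch in which the checkpoint name, the toy multi-label dataset, the number of labels, and the training arguments are illustrative placeholders, not part of the original recipe:

# Usage sketch for MultilabelTrainer (checkpoint, toy data, and
# hyperparameters are illustrative placeholders)
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)

texts = ["a short example", "another short example"]
encodings = tokenizer(texts, truncation=True, padding=True)

# Tiny in-memory dataset with multi-hot float labels, one row per text
class ToyMultilabelDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

train_data = ToyMultilabelDataset(encodings, [[1, 0, 1], [0, 1, 0]])

training_args = TrainingArguments(
    output_dir="./multilabel-results",
    num_train_epochs=1,
    per_device_train_batch_size=2,
)

trainer = MultilabelTrainer(model=model, args=training_args, train_dataset=train_data)
trainer.train()    # uses the overridden compute_loss with BCEWithLogitsLoss

Because the labels are popped from the inputs before the forward pass, the model never computes its own loss; MultilabelTrainer supplies BCEWithLogitsLoss instead, which is the standard choice for multi-label targets.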

In this way, we can customize the Trainer in transformers by subclassing it and overriding methods such as compute_loss.

