How to visualise XGBoost feature importance in R?

This recipe helps you visualise XGBoost feature importance in R

Recipe Objective

Classification and regression are supervised learning problems that can be solved with algorithms such as linear regression, logistic regression and decision trees. On their own, however, these single models often fall short of the predictive accuracy that ensembles reach. Ensemble techniques instead build multiple models and combine them into one to produce stronger predictions.

Bagging, boosting and random forests are common ensemble techniques. Boosting is a sequential ensemble technique in which each new model is improved using information from the previously grown, weaker models. This process continues for multiple iterations until a final, more accurate model is built.

Three widely used boosting algorithms are:

  1. AdaBoost
  2. Gradient Boosting
  3. XGBoost

In recent years, researchers and enthusiasts have used ensemble techniques like XGBoost to win data science competitions and hackathons. On structured (tabular) data, XGBoost often outperforms algorithms such as Random Forest and standard Gradient Boosting in both speed and accuracy.

XGBoost is an ensemble model built from decision trees. A single shallow decision tree is considered a weak learner. The algorithm builds trees sequentially, with each tree correcting the errors left by the previous ones, until a stopping condition (such as a fixed number of rounds) is met.
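The sequential, error-correcting idea can be sketched in a few lines of base R. This is plain gradient boosting for squared error with one-split decision stumps as the weak learners, not XGBoost itself (XGBoost adds regularisation, second-order gradient information and deeper trees). The toy data are simulated, and the helpers fit_stump(), predict_stump() and boost() are hypothetical names used only for illustration.

```r
# Plain gradient boosting (squared error) with one-split "stumps" as weak
# learners. Illustrative sketch only; this is not XGBoost's algorithm.

# Fit a single-split regression stump to the residuals r: try every feature
# and threshold, keep the split with the smallest squared error.
fit_stump <- function(X, r) {
  best <- list(sse = Inf)
  for (j in seq_len(ncol(X))) {
    for (t in unique(X[, j])) {
      left <- X[, j] <= t
      if (all(left)) next                   # a split must send some rows right
      pred <- ifelse(left, mean(r[left]), mean(r[!left]))
      sse <- sum((r - pred)^2)
      if (sse < best$sse) {
        best <- list(sse = sse, j = j, t = t,
                     left_val = mean(r[left]), right_val = mean(r[!left]))
      }
    }
  }
  best
}

predict_stump <- function(s, X) ifelse(X[, s$j] <= s$t, s$left_val, s$right_val)

# Each round fits a stump to what the current ensemble still gets wrong
# (the residuals) and adds a shrunken copy of it to the prediction.
boost <- function(X, y, n_rounds = 50, eta = 0.3) {
  pred <- rep(mean(y), length(y))           # start from the mean of y
  rmse <- numeric(n_rounds)
  for (m in seq_len(n_rounds)) {
    s <- fit_stump(X, y - pred)             # weak learner on the residuals
    pred <- pred + eta * predict_stump(s, X)
    rmse[m] <- sqrt(mean((y - pred)^2))     # training RMSE after round m
  }
  rmse
}

# Simulated toy data with two features
set.seed(1)
X <- matrix(runif(200), ncol = 2)
y <- 3 * X[, 1] + sin(6 * X[, 2]) + rnorm(100, sd = 0.1)
train_rmse <- boost(X, y)
print(round(train_rmse[c(1, 10, 50)], 3))
```

Each round lowers the training RMSE a little, the same pattern that the train-rmse column shows in the xgb.train() log later in this recipe.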

In this recipe, we will build an XGBoost regression model and visualise its feature importance.

STEP 1: Importing Necessary Libraries

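# The calls below follow the xgboost 1.x interface; newer releases have
# changed some argument names, so check the docs for your installed version.
library(caret)     # createDataPartition() for the train/test split
library(xgboost)   # xgb.DMatrix(), xgb.train(), xgboost(), xgb.importance()
library(tidyverse) # glimpse() and general data manipulation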

STEP 2: Read a csv file and explore the data

The attached dataset contains data on 159 different bags from ABC industries.

The bags have the following attributes:

  1. Height – The height of the bag
  2. Width – The width of the bag
  3. Length – The length of the bag
  4. Weight – The weight the bag can carry
  5. Weight1 – Weight the bag can carry after expansion

The company now wants to predict the cost it should set for a new variant of these bags.

data <- read.csv("R_357_Data_1.csv")
glimpse(data)

Rows: 159
Columns: 6
$ Cost     242, 290, 340, 363, 430, 450, 500, 390, 450, 500, 475, 500,...
$ Weight   23.2, 24.0, 23.9, 26.3, 26.5, 26.8, 26.8, 27.6, 27.6, 28.5,...
$ Weight1  25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7,...
$ Length   30.0, 31.2, 31.1, 33.5, 34.0, 34.7, 34.5, 35.0, 35.1, 36.2,...
$ Height   11.5200, 12.4800, 12.3778, 12.7300, 12.4440, 13.6024, 14.17...
$ Width    4.0200, 4.3056, 4.6961, 4.4555, 5.1340, 4.9274, 5.2785, 4.6...

summary(data) # returns the statistical summary of the data columns

Cost            Weight         Weight1          Length     
 Min.   :   0.0   Min.   : 7.50   Min.   : 8.40   Min.   : 8.80  
 1st Qu.: 120.0   1st Qu.:19.05   1st Qu.:21.00   1st Qu.:23.15  
 Median : 273.0   Median :25.20   Median :27.30   Median :29.40  
 Mean   : 398.3   Mean   :26.25   Mean   :28.42   Mean   :31.23  
 3rd Qu.: 650.0   3rd Qu.:32.70   3rd Qu.:35.50   3rd Qu.:39.65  
 Max.   :1650.0   Max.   :59.00   Max.   :63.40   Max.   :68.00  
     Height           Width      
 Min.   : 1.728   Min.   :1.048  
 1st Qu.: 5.945   1st Qu.:3.386  
 Median : 7.786   Median :4.248  
 Mean   : 8.971   Mean   :4.417  
 3rd Qu.:12.366   3rd Qu.:5.585  
 Max.   :18.957   Max.   :8.142   

dim(data)

[1] 159   6

STEP 3: Train Test Split

# Split the data into training (80%) and testing (20%) sets with
# createDataPartition() from caret. Call set.seed() first if you need
# the same split on every run.
parts = createDataPartition(data$Cost, p = 0.8, list = FALSE)
train = data[parts, ]
test = data[-parts, ]

# Define predictor and response variables in the training set
# (column 1 is Cost, the response; the other columns are predictors)
train_x = data.matrix(train[, -1])
train_y = train[, 1]

# Define predictor and response variables in the testing set
test_x = data.matrix(test[, -1])
test_y = test[, 1]

# Wrap them in xgb.DMatrix objects, the input format xgboost trains on
xgb_train = xgb.DMatrix(data = train_x, label = train_y)
xgb_test = xgb.DMatrix(data = test_x, label = test_y)
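For a numeric outcome such as Cost, createDataPartition() does not simply draw a random 80% of rows: it splits y into percentile-based groups and samples within each group, so the training and testing sets cover a similar range of costs. The base-R sketch below illustrates that idea under our own simplifying choices; the helper stratified_split(), the number of bins and the per-bin rounding are assumptions for illustration, not caret's implementation or exact defaults.

```r
# Percentile-stratified split for a numeric outcome: bin y at its quantiles,
# then sample a fraction p of the rows inside every bin.
# stratified_split() is a hypothetical helper, not caret's code.
stratified_split <- function(y, p = 0.8, n_bins = 4) {
  breaks <- unique(quantile(y, probs = seq(0, 1, length.out = n_bins + 1)))
  bins <- cut(y, breaks = breaks, include.lowest = TRUE)
  idx <- lapply(split(seq_along(y), bins), function(i) {
    if (length(i) == 0) return(integer(0))  # an empty bin contributes no rows
    k <- max(1, round(p * length(i)))       # rows to keep from this bin
    i[sample.int(length(i), k)]             # indexing avoids sample()'s 1:n trap
  })
  sort(unlist(idx, use.names = FALSE))
}

# First 12 Cost values shown by glimpse(); two bins (below/above the median)
set.seed(42)
cost <- c(242, 290, 340, 363, 430, 450, 500, 390, 450, 500, 475, 500)
train_idx <- stratified_split(cost, p = 0.8, n_bins = 2)
test_idx  <- setdiff(seq_along(cost), train_idx)
print(cost[train_idx])
print(cost[test_idx])
```

With two bins of six values each, five rows are drawn from each bin, giving a 10/2 split. On the full 159-row dataset, per-bin rounding has much less effect on the overall 80/20 ratio.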

STEP 4: Create an XGBoost model

Now we fit the model with the xgb.train() function, which prints the training and testing root mean squared error (RMSE) after each boosting round.

The max.depth parameter sets the maximum depth each tree can grow to; we choose a value of 3.
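RMSE is the square root of the mean squared difference between actual and predicted values, so it is expressed in the same units as Cost. The short function below computes it; rmse() is our own helper name, not part of xgboost.

```r
# Root mean squared error: sqrt(mean((actual - predicted)^2))
rmse <- function(actual, predicted) {
  stopifnot(length(actual) == length(predicted))
  sqrt(mean((actual - predicted)^2))
}

# Toy check: errors of 0, 0 and 2 give sqrt(4 / 3)
rmse(c(1, 2, 3), c(1, 2, 5))   # 1.154701
```

Once the final model below is trained, a call such as rmse(test_y, predict(model_xgboost, xgb_test)) should closely match the test-rmse reported for the chosen round (this assumes the objects defined in this recipe).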

# Define a watchlist so RMSE is reported on both sets
watchlist = list(train = xgb_train, test = xgb_test)

# Fit the XGBoost model and display training and testing RMSE at each iteration
model = xgb.train(data = xgb_train, max.depth = 3, watchlist = watchlist, nrounds = 100)

[1]	train-rmse:374.441406	test-rmse:481.788391 
[2]	train-rmse:274.574158	test-rmse:377.512909 
[3]	train-rmse:204.863098	test-rmse:306.634033 
[4]	train-rmse:155.649658	test-rmse:251.804932 
[5]	train-rmse:119.886559	test-rmse:206.584793 
[6]	train-rmse:94.443649	test-rmse:170.362732 
[7]	train-rmse:76.098549	test-rmse:157.283279 
[8]	train-rmse:63.038189	test-rmse:148.384521 
[9]	train-rmse:53.171177	test-rmse:142.591125 
[10]	train-rmse:46.219536	test-rmse:126.492058 
[11]	train-rmse:41.068180	test-rmse:112.861725 
[12]	train-rmse:37.273392	test-rmse:101.792809 
[13]	train-rmse:33.991714	test-rmse:99.646431 
[14]	train-rmse:31.665110	test-rmse:91.611916 
[15]	train-rmse:29.955919	test-rmse:84.864738 
[16]	train-rmse:28.531353	test-rmse:79.398239 
[17]	train-rmse:27.040276	test-rmse:74.698051 
[18]	train-rmse:26.302597	test-rmse:70.936241 
[19]	train-rmse:25.201057	test-rmse:67.750641 
[20]	train-rmse:24.487757	test-rmse:65.076195 
[21]	train-rmse:23.867445	test-rmse:65.166847 
[22]	train-rmse:22.876081	test-rmse:63.112698 
[23]	train-rmse:22.164562	test-rmse:61.523403 
[24]	train-rmse:21.816034	test-rmse:61.467430 
[25]	train-rmse:21.125587	test-rmse:61.402748 
[26]	train-rmse:20.957186	test-rmse:60.343128 
[27]	train-rmse:20.365843	test-rmse:60.348598 
[28]	train-rmse:20.168547	test-rmse:59.282814 
[29]	train-rmse:18.995090	test-rmse:58.969128 
[30]	train-rmse:18.819603	test-rmse:59.020538 
[31]	train-rmse:18.699118	test-rmse:58.379250 
[32]	train-rmse:17.504850	test-rmse:57.781509 
[33]	train-rmse:17.387026	test-rmse:57.645771 
[34]	train-rmse:17.037064	test-rmse:57.125183 
[35]	train-rmse:16.668007	test-rmse:56.830990 
[36]	train-rmse:16.044168	test-rmse:56.780052 
[37]	train-rmse:15.536475	test-rmse:56.567234 
[38]	train-rmse:15.433763	test-rmse:56.546337 
[39]	train-rmse:15.098138	test-rmse:56.664021 
[40]	train-rmse:14.819264	test-rmse:56.322807 
[41]	train-rmse:14.625785	test-rmse:56.316051 
[42]	train-rmse:14.350323	test-rmse:56.248844 
[43]	train-rmse:14.131385	test-rmse:56.189671 
[44]	train-rmse:13.516161	test-rmse:56.011814 
[45]	train-rmse:13.048274	test-rmse:56.140182 
[46]	train-rmse:12.758994	test-rmse:55.925411 
[47]	train-rmse:12.444994	test-rmse:56.098057 
[48]	train-rmse:12.082834	test-rmse:56.063778 
[49]	train-rmse:11.696443	test-rmse:56.002361 
[50]	train-rmse:11.560493	test-rmse:56.020744 
[51]	train-rmse:11.102805	test-rmse:56.114948 
[52]	train-rmse:10.627243	test-rmse:56.106552 
[53]	train-rmse:10.547875	test-rmse:56.181263 
[54]	train-rmse:10.363978	test-rmse:55.970352 
[55]	train-rmse:10.133872	test-rmse:56.034210 
[56]	train-rmse:9.734212	test-rmse:56.160725 
[57]	train-rmse:9.508077	test-rmse:56.177059 
[58]	train-rmse:9.202065	test-rmse:56.142998 
[59]	train-rmse:8.973363	test-rmse:56.266232 
[60]	train-rmse:8.868542	test-rmse:55.894310 
[61]	train-rmse:8.712481	test-rmse:55.797413 
[62]	train-rmse:8.450444	test-rmse:55.796597 
[63]	train-rmse:8.261618	test-rmse:55.789951 
[64]	train-rmse:8.081842	test-rmse:55.639320 
[65]	train-rmse:7.938920	test-rmse:55.682808 
[66]	train-rmse:7.682938	test-rmse:55.756508 
[67]	train-rmse:7.553942	test-rmse:55.836765 
[68]	train-rmse:7.432102	test-rmse:55.685822 
[69]	train-rmse:7.294747	test-rmse:55.697899 
[70]	train-rmse:7.103888	test-rmse:55.749569 
[71]	train-rmse:6.905044	test-rmse:55.763145 
[72]	train-rmse:6.753871	test-rmse:55.844006 
[73]	train-rmse:6.690207	test-rmse:55.758812 
[74]	train-rmse:6.488435	test-rmse:55.740368 
[75]	train-rmse:6.299417	test-rmse:55.737957 
[76]	train-rmse:6.090727	test-rmse:55.710434 
[77]	train-rmse:5.966695	test-rmse:55.743229 
[78]	train-rmse:5.857632	test-rmse:55.720600 
[79]	train-rmse:5.828579	test-rmse:55.569942 
[80]	train-rmse:5.622557	test-rmse:55.612438 
[81]	train-rmse:5.563855	test-rmse:55.506058 
[82]	train-rmse:5.381149	test-rmse:55.447449 
[83]	train-rmse:5.306352	test-rmse:55.385094 
[84]	train-rmse:5.159195	test-rmse:55.371307 
[85]	train-rmse:5.009599	test-rmse:55.202850 
[86]	train-rmse:4.988478	test-rmse:55.135273 
[87]	train-rmse:4.858966	test-rmse:55.196877 
[88]	train-rmse:4.761164	test-rmse:55.197235 
[89]	train-rmse:4.649740	test-rmse:55.199398 
[90]	train-rmse:4.545322	test-rmse:55.266251 
[91]	train-rmse:4.471013	test-rmse:55.323376 
[92]	train-rmse:4.442612	test-rmse:55.336811 
[93]	train-rmse:4.399715	test-rmse:55.298866 
[94]	train-rmse:4.289005	test-rmse:55.273613 
[95]	train-rmse:4.196774	test-rmse:55.273048 
[96]	train-rmse:4.097926	test-rmse:55.231594 
[97]	train-rmse:3.942547	test-rmse:55.206097 
[98]	train-rmse:3.923210	test-rmse:55.145107 
[99]	train-rmse:3.835154	test-rmse:55.166672 
[100]	train-rmse:3.761758	test-rmse:55.160030 

# Define the final model; nrounds = 86 is the round with the lowest
# test-rmse (55.135273) in the log above
model_xgboost = xgboost(data = xgb_train, max.depth = 3, nrounds = 86, verbose = 0)
summary(model_xgboost)
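Rather than reading the best round off the log by eye, it can be found with which.min(). The sketch below copies rounds 82 to 90 from the printed log into a data frame. With the xgboost 1.x R package, the fitted booster should carry the same log as model$evaluation_log, with columns such as iter and test_rmse; treat those column names as an assumption and check them with names() on your version. Note that choosing nrounds on the test set makes its reported RMSE somewhat optimistic.

```r
# Rounds 82-90 of the training log above (iteration and test-rmse only)
eval_log <- data.frame(
  iter      = 82:90,
  test_rmse = c(55.447449, 55.385094, 55.371307, 55.202850, 55.135273,
                55.196877, 55.197235, 55.199398, 55.266251)
)

# Round with the lowest test RMSE
best_iter <- eval_log$iter[which.min(eval_log$test_rmse)]
best_iter   # 86
```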

               Length Class              Mode       
handle             1  xgb.Booster.handle externalptr
raw            91316  -none-             raw        
niter              1  -none-             numeric    
evaluation_log     2  data.table         list       
call              14  -none-             call       
params             2  -none-             list       
callbacks          1  -none-             list       
feature_names      5  -none-             character  
nfeatures          1  -none-             numeric    

STEP 5: Visualising XGBoost feature importance

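We use xgb.importance() to compute the importance matrix, passing the feature names and the fitted model. For each feature it reports three measures, each normalised so that it sums to 1 across features:

  1. Gain – the feature's share of the total loss reduction from all splits that use it
  2. Cover – the relative number of observations affected by splits on the feature
  3. Frequency – the share of all splits in the model that use the feature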
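As we understand the xgboost 1.x R package, xgb.importance() builds these columns from the model's tree dump: for each feature it sums the gain and cover of the non-leaf nodes that split on it, counts those nodes, and divides each column by its total. The sketch below repeats that arithmetic in base R on a few made-up split records, so the numbers are illustrative only and do not come from this model.

```r
# Hypothetical split records: one row per internal (non-leaf) node across
# all trees. These values are invented for illustration.
splits <- data.frame(
  feature = c("Width", "Width", "Length", "Height", "Height", "Height"),
  gain    = c(50, 30, 15, 2, 2, 1),
  cover   = c(40, 20, 30, 25, 25, 10),
  stringsAsFactors = FALSE
)

# Aggregate per feature: total gain, total cover and number of splits
gain  <- tapply(splits$gain,  splits$feature, sum)
cover <- tapply(splits$cover, splits$feature, sum)
freq  <- tapply(splits$gain,  splits$feature, length)

# Normalise each measure so it sums to 1, then sort by Gain
importance <- data.frame(
  Feature   = names(gain),
  Gain      = as.numeric(gain / sum(gain)),
  Cover     = as.numeric(cover / sum(cover)),
  Frequency = as.numeric(freq / sum(freq)),
  stringsAsFactors = FALSE
)
importance <- importance[order(-importance$Gain), ]
print(importance)
```

In this toy example Height has the smallest Gain but the largest Frequency, because it is split on often with little loss reduction each time.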

# Compute the feature importance matrix
importance_matrix = xgb.importance(colnames(xgb_train), model = model_xgboost)
importance_matrix

Feature        Gain      Cover  Frequency
Width   0.636898215 0.26837467 0.25553320
Length  0.272275966 0.17613034 0.16498994
Weight  0.069464120 0.22846068 0.26760563
Height  0.016696726 0.30477575 0.28370221
Weight1 0.004664973 0.02225856 0.02816901
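The three measures need not agree, and here they do not: Width carries most of the Gain, while Height is used in the most splits (Frequency) and affects the most observations (Cover). Ranking the table under each measure makes this explicit; the values below are copied from the importance matrix above.

```r
# Importance matrix printed above
imp <- data.frame(
  Feature   = c("Width", "Length", "Weight", "Height", "Weight1"),
  Gain      = c(0.636898215, 0.272275966, 0.069464120, 0.016696726, 0.004664973),
  Cover     = c(0.26837467, 0.17613034, 0.22846068, 0.30477575, 0.02225856),
  Frequency = c(0.25553320, 0.16498994, 0.26760563, 0.28370221, 0.02816901),
  stringsAsFactors = FALSE
)

# Rank features separately under each measure (row 1 = most important)
ranks <- sapply(imp[, -1], function(m) imp$Feature[order(m, decreasing = TRUE)])
print(ranks)
```

Because Gain reflects how much each feature actually reduces the loss, it is usually the column to read for predictive contribution; Frequency and Cover can rank a feature highly when it is split on often but with little benefit, as Height is here.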

# Plot the importance of all five features (ranked by Gain by default)
xgb.plot.importance(importance_matrix[1:5, ])
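xgb.plot.importance() draws a horizontal bar chart of one measure, Gain by default for tree models (pass measure = "Cover" or measure = "Frequency" to plot the others). If you want a similar chart without the xgboost plotting helpers, for example to restyle it, base R's barplot() is enough. This is a sketch using the Gain column printed above, not a reproduction of xgboost's own plot.

```r
# Gain values from the importance matrix printed above
gain <- c(Width = 0.636898215, Length = 0.272275966, Weight = 0.069464120,
          Height = 0.016696726, Weight1 = 0.004664973)

# Horizontal bars with the largest at the top (barplot draws the first
# element at the bottom, so the vector is reversed)
bar_mid <- barplot(rev(gain), horiz = TRUE, las = 1,
                   xlab = "Gain (fraction of total)",
                   main = "XGBoost feature importance")
```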
