How to visualise an XGBoost tree in R?

This recipe helps you visualise an XGBoost tree in R.

Recipe Objective

Classification and regression are supervised learning problems that can be solved with algorithms such as linear regression, logistic regression, and decision trees. On their own, however, these models are often not competitive in prediction accuracy. Ensemble techniques, on the other hand, build multiple models and combine them into one to produce better results.

Bagging, boosting, and random forests are different types of ensemble techniques. Boosting is a sequential ensemble technique in which each new model is improved using information from the previously grown, weaker models. This process continues for multiple iterations until a final model is built that predicts the outcome more accurately.

Three widely used boosting techniques are:

  1. AdaBoost
  2. Gradient Boosting
  3. XGBoost

Recently, researchers and enthusiasts have started using ensemble techniques like XGBoost to win data science competitions and hackathons. On structured data, it often outperforms algorithms such as Random Forest and Gradient Boosting in terms of both speed and accuracy.

XGBoost uses an ensemble model based on decision trees. A simple decision tree is considered a weak learner. The algorithm builds decision trees sequentially, where each tree corrects the errors made by the previous ones, until a stopping condition is met (for example, a fixed number of rounds).
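The idea of each tree correcting the errors of the previous ones can be sketched in base R. This is a minimal illustration of plain boosting for squared error with one-split "stumps", not XGBoost's actual algorithm (which adds regularisation and other refinements). The toy data and the helper names fit_stump, predict_stump and eta are assumptions made for this example.

```r
# Toy illustration of boosting for regression with squared error.
# Each round fits a one-split "stump" to the residuals of the current ensemble.
# fit_stump(), predict_stump() and eta are hypothetical names for this sketch.

set.seed(42)
x <- seq(0, 10, length.out = 100)
y <- sin(x) + rnorm(100, sd = 0.1)

# Find the single threshold on x that minimises the squared error of r
fit_stump <- function(x, r) {
  xs <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2   # midpoints between neighbours
  sse <- sapply(cuts, function(cut) {
    left <- r[x <= cut]
    right <- r[x > cut]
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  best <- cuts[which.min(sse)]
  list(cut = best,
       left = mean(r[x <= best]),
       right = mean(r[x > best]))
}

predict_stump <- function(stump, x) {
  ifelse(x <= stump$cut, stump$left, stump$right)
}

eta <- 0.3                       # learning rate (shrinkage)
n_rounds <- 50
pred <- rep(mean(y), length(y))  # start from a constant prediction
rmse_by_round <- numeric(n_rounds)

for (i in seq_len(n_rounds)) {
  residual <- y - pred                          # errors of the current ensemble
  stump <- fit_stump(x, residual)               # weak learner fitted to those errors
  pred <- pred + eta * predict_stump(stump, x)  # add a shrunken correction
  rmse_by_round[i] <- sqrt(mean((y - pred)^2))
}

print(round(rmse_by_round[c(1, 10, 50)], 3))
```

The training error recorded in rmse_by_round shrinks as rounds are added, which is the same pattern the train-rmse column shows when the real model is fitted later in this recipe.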

In this recipe, we will discuss how to build and visualise an XGBoost tree.

STEP 1: Importing Necessary Libraries

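library(caret)      # for general data preparation and model fitting
library(rpart.plot)
library(tidyverse)
library(xgboost)    # for xgb.DMatrix(), xgb.train(), xgboost() and xgb.plot.tree()
# Note: the DiagrammeR package must be installed for xgb.plot.tree() to render trees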

STEP 2: Read a csv file and explore the data

The attached dataset contains data on 159 different bags associated with ABC industries.

The bags have certain attributes, which are described below:

  1. Height – The height of the bag
  2. Width – The width of the bag
  3. Length – The length of the bag
  4. Weight – The weight the bag can carry
  5. Weight1 – Weight the bag can carry after expansion

The company now wants to predict the cost it should set for a new variant of these bags, so Cost is the response variable and the five attributes above are the predictors.

data <- read.csv("R_356_Data_1.csv")
glimpse(data)

Rows: 159
Columns: 6
$ Cost     242, 290, 340, 363, 430, 450, 500, 390, 450, 500, 475, 500,...
$ Weight   23.2, 24.0, 23.9, 26.3, 26.5, 26.8, 26.8, 27.6, 27.6, 28.5,...
$ Weight1  25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7,...
$ Length   30.0, 31.2, 31.1, 33.5, 34.0, 34.7, 34.5, 35.0, 35.1, 36.2,...
$ Height   11.5200, 12.4800, 12.3778, 12.7300, 12.4440, 13.6024, 14.17...
$ Width    4.0200, 4.3056, 4.6961, 4.4555, 5.1340, 4.9274, 5.2785, 4.6...

summary(data) # returns the statistical summary of the data columns

Cost            Weight         Weight1          Length     
 Min.   :   0.0   Min.   : 7.50   Min.   : 8.40   Min.   : 8.80  
 1st Qu.: 120.0   1st Qu.:19.05   1st Qu.:21.00   1st Qu.:23.15  
 Median : 273.0   Median :25.20   Median :27.30   Median :29.40  
 Mean   : 398.3   Mean   :26.25   Mean   :28.42   Mean   :31.23  
 3rd Qu.: 650.0   3rd Qu.:32.70   3rd Qu.:35.50   3rd Qu.:39.65  
 Max.   :1650.0   Max.   :59.00   Max.   :63.40   Max.   :68.00  
     Height           Width      
 Min.   : 1.728   Min.   :1.048  
 1st Qu.: 5.945   1st Qu.:3.386  
 Median : 7.786   Median :4.248  
 Mean   : 8.971   Mean   :4.417  
 3rd Qu.:12.366   3rd Qu.:5.585  
 Max.   :18.957   Max.   :8.142   

dim(data)

159 6

STEP 3: Train Test Split

# Use createDataPartition() from the caret package to split the original
# dataset into training (80%) and testing (20%) sets
parts = createDataPartition(data$Cost, p = .8, list = FALSE)
train = data[parts, ]
test = data[-parts, ]

# define predictor and response variables in training set
train_x = data.matrix(train[, -1])
train_y = train[, 1]

# define predictor and response variables in testing set
test_x = data.matrix(test[, -1])
test_y = test[, 1]

# define final training and testing sets
xgb_train = xgb.DMatrix(data = train_x, label = train_y)
xgb_test = xgb.DMatrix(data = test_x, label = test_y)
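For a numeric outcome, createDataPartition() samples within quantile groups of the response so that both sets cover a similar range of costs. To show the mechanics of the split without caret, here is a simpler, unstratified sketch in base R. The data frame toy is made up for illustration and only stands in for the bag data; set.seed() is added so the random split can be reproduced.

```r
# Unstratified 80/20 split in base R (a simplification of createDataPartition()).
# `toy` is a made-up data frame standing in for the bag data.
set.seed(123)                      # makes the random split reproducible
toy <- data.frame(Cost = runif(159, 0, 1650),
                  Weight = runif(159, 7.5, 59))

n_train <- floor(0.8 * nrow(toy))                  # number of training rows
train_idx <- sample(seq_len(nrow(toy)), n_train)   # row numbers drawn at random

toy_train <- toy[train_idx, ]
toy_test  <- toy[-train_idx, ]

# The response is the first column; the predictors are everything else
toy_train_x <- data.matrix(toy_train[, -1, drop = FALSE])
toy_train_y <- toy_train[, 1]

dim(toy_train_x)
nrow(toy_test)
```

Because the split is random, any model metrics will differ from run to run unless a seed is fixed before the split.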

STEP 4: Create an XGBoost model

Now, we will fit and train our model using the xgb.train() function, which prints the training and testing root mean squared error (RMSE) for each boosting round.
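As a reminder of what the reported metric measures, RMSE is the square root of the mean squared difference between actual and predicted values. A minimal sketch follows; rmse is a hypothetical helper, the actual values are the first four costs from the data preview, and the predicted values are invented purely for illustration.

```r
# Root mean squared error between actual and predicted values.
# `rmse` is a hypothetical helper name, not an xgboost function.
rmse <- function(actual, predicted) {
  sqrt(mean((actual - predicted)^2))
}

actual    <- c(242, 290, 340, 363)   # first four Cost values from the data preview
predicted <- c(250, 280, 350, 360)   # made-up predictions for illustration

example_rmse <- rmse(actual, predicted)
print(example_rmse)
```

RMSE is expressed in the same units as Cost, so a test RMSE of about 55 means predictions are typically off by roughly that amount.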

Here, the max.depth parameter determines how deep each tree can grow; we choose a value of 3.

# defining a watchlist
watchlist = list(train = xgb_train, test = xgb_test)

# fit XGBoost model and display training and testing RMSE at each iteration
model = xgb.train(data = xgb_train, max.depth = 3, watchlist = watchlist, nrounds = 100)

[1]	train-rmse:374.441406	test-rmse:481.788391 
[2]	train-rmse:274.574158	test-rmse:377.512909 
[3]	train-rmse:204.863098	test-rmse:306.634033 
[4]	train-rmse:155.649658	test-rmse:251.804932 
[5]	train-rmse:119.886559	test-rmse:206.584793 
[6]	train-rmse:94.443649	test-rmse:170.362732 
[7]	train-rmse:76.098549	test-rmse:157.283279 
[8]	train-rmse:63.038189	test-rmse:148.384521 
[9]	train-rmse:53.171177	test-rmse:142.591125 
[10]	train-rmse:46.219536	test-rmse:126.492058 
[11]	train-rmse:41.068180	test-rmse:112.861725 
[12]	train-rmse:37.273392	test-rmse:101.792809 
[13]	train-rmse:33.991714	test-rmse:99.646431 
[14]	train-rmse:31.665110	test-rmse:91.611916 
[15]	train-rmse:29.955919	test-rmse:84.864738 
[16]	train-rmse:28.531353	test-rmse:79.398239 
[17]	train-rmse:27.040276	test-rmse:74.698051 
[18]	train-rmse:26.302597	test-rmse:70.936241 
[19]	train-rmse:25.201057	test-rmse:67.750641 
[20]	train-rmse:24.487757	test-rmse:65.076195 
[21]	train-rmse:23.867445	test-rmse:65.166847 
[22]	train-rmse:22.876081	test-rmse:63.112698 
[23]	train-rmse:22.164562	test-rmse:61.523403 
[24]	train-rmse:21.816034	test-rmse:61.467430 
[25]	train-rmse:21.125587	test-rmse:61.402748 
[26]	train-rmse:20.957186	test-rmse:60.343128 
[27]	train-rmse:20.365843	test-rmse:60.348598 
[28]	train-rmse:20.168547	test-rmse:59.282814 
[29]	train-rmse:18.995090	test-rmse:58.969128 
[30]	train-rmse:18.819603	test-rmse:59.020538 
[31]	train-rmse:18.699118	test-rmse:58.379250 
[32]	train-rmse:17.504850	test-rmse:57.781509 
[33]	train-rmse:17.387026	test-rmse:57.645771 
[34]	train-rmse:17.037064	test-rmse:57.125183 
[35]	train-rmse:16.668007	test-rmse:56.830990 
[36]	train-rmse:16.044168	test-rmse:56.780052 
[37]	train-rmse:15.536475	test-rmse:56.567234 
[38]	train-rmse:15.433763	test-rmse:56.546337 
[39]	train-rmse:15.098138	test-rmse:56.664021 
[40]	train-rmse:14.819264	test-rmse:56.322807 
[41]	train-rmse:14.625785	test-rmse:56.316051 
[42]	train-rmse:14.350323	test-rmse:56.248844 
[43]	train-rmse:14.131385	test-rmse:56.189671 
[44]	train-rmse:13.516161	test-rmse:56.011814 
[45]	train-rmse:13.048274	test-rmse:56.140182 
[46]	train-rmse:12.758994	test-rmse:55.925411 
[47]	train-rmse:12.444994	test-rmse:56.098057 
[48]	train-rmse:12.082834	test-rmse:56.063778 
[49]	train-rmse:11.696443	test-rmse:56.002361 
[50]	train-rmse:11.560493	test-rmse:56.020744 
[51]	train-rmse:11.102805	test-rmse:56.114948 
[52]	train-rmse:10.627243	test-rmse:56.106552 
[53]	train-rmse:10.547875	test-rmse:56.181263 
[54]	train-rmse:10.363978	test-rmse:55.970352 
[55]	train-rmse:10.133872	test-rmse:56.034210 
[56]	train-rmse:9.734212	test-rmse:56.160725 
[57]	train-rmse:9.508077	test-rmse:56.177059 
[58]	train-rmse:9.202065	test-rmse:56.142998 
[59]	train-rmse:8.973363	test-rmse:56.266232 
[60]	train-rmse:8.868542	test-rmse:55.894310 
[61]	train-rmse:8.712481	test-rmse:55.797413 
[62]	train-rmse:8.450444	test-rmse:55.796597 
[63]	train-rmse:8.261618	test-rmse:55.789951 
[64]	train-rmse:8.081842	test-rmse:55.639320 
[65]	train-rmse:7.938920	test-rmse:55.682808 
[66]	train-rmse:7.682938	test-rmse:55.756508 
[67]	train-rmse:7.553942	test-rmse:55.836765 
[68]	train-rmse:7.432102	test-rmse:55.685822 
[69]	train-rmse:7.294747	test-rmse:55.697899 
[70]	train-rmse:7.103888	test-rmse:55.749569 
[71]	train-rmse:6.905044	test-rmse:55.763145 
[72]	train-rmse:6.753871	test-rmse:55.844006 
[73]	train-rmse:6.690207	test-rmse:55.758812 
[74]	train-rmse:6.488435	test-rmse:55.740368 
[75]	train-rmse:6.299417	test-rmse:55.737957 
[76]	train-rmse:6.090727	test-rmse:55.710434 
[77]	train-rmse:5.966695	test-rmse:55.743229 
[78]	train-rmse:5.857632	test-rmse:55.720600 
[79]	train-rmse:5.828579	test-rmse:55.569942 
[80]	train-rmse:5.622557	test-rmse:55.612438 
[81]	train-rmse:5.563855	test-rmse:55.506058 
[82]	train-rmse:5.381149	test-rmse:55.447449 
[83]	train-rmse:5.306352	test-rmse:55.385094 
[84]	train-rmse:5.159195	test-rmse:55.371307 
[85]	train-rmse:5.009599	test-rmse:55.202850 
[86]	train-rmse:4.988478	test-rmse:55.135273 
[87]	train-rmse:4.858966	test-rmse:55.196877 
[88]	train-rmse:4.761164	test-rmse:55.197235 
[89]	train-rmse:4.649740	test-rmse:55.199398 
[90]	train-rmse:4.545322	test-rmse:55.266251 
[91]	train-rmse:4.471013	test-rmse:55.323376 
[92]	train-rmse:4.442612	test-rmse:55.336811 
[93]	train-rmse:4.399715	test-rmse:55.298866 
[94]	train-rmse:4.289005	test-rmse:55.273613 
[95]	train-rmse:4.196774	test-rmse:55.273048 
[96]	train-rmse:4.097926	test-rmse:55.231594 
[97]	train-rmse:3.942547	test-rmse:55.206097 
[98]	train-rmse:3.923210	test-rmse:55.145107 
[99]	train-rmse:3.835154	test-rmse:55.166672 
[100]	train-rmse:3.761758	test-rmse:55.160030 
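In the log above, the training RMSE keeps falling while the test RMSE levels off at around 55 to 56, so later rounds mostly fit the training data more closely. One common way to choose the number of rounds is to take the round with the lowest test RMSE. The sketch below does this on a few rows copied from the log; the data frame eval_log is built by hand for illustration, and the commented line assumes the installed xgboost version stores the full log in model$evaluation_log.

```r
# A few rounds copied from the training log above
eval_log <- data.frame(
  iter      = c(84, 85, 86, 87, 88, 98, 100),
  test_rmse = c(55.371307, 55.202850, 55.135273, 55.196877,
                55.197235, 55.145107, 55.160030)
)

best_row  <- which.min(eval_log$test_rmse)   # position of the smallest test RMSE
best_iter <- eval_log$iter[best_row]         # boosting round it belongs to
print(best_iter)

# With the fitted booster, the full log could be used instead (assumption:
# the installed xgboost version exposes it as model$evaluation_log):
# best_iter <- which.min(model$evaluation_log$test_rmse)
```

Choosing the round on the test set makes the reported test RMSE slightly optimistic; a separate validation set or cross-validation would avoid that.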

# define final model
# nrounds = 86 because the test RMSE in the log above is lowest at round 86
model_xgboost = xgboost(data = xgb_train, max.depth = 3, nrounds = 86, verbose = 0)
summary(model_xgboost)

               Length Class              Mode       
handle             1  xgb.Booster.handle externalptr
raw            91316  -none-             raw        
niter              1  -none-             numeric    
evaluation_log     2  data.table         list       
call              14  -none-             call       
params             2  -none-             list       
callbacks          1  -none-             list       
feature_names      5  -none-             character  
nfeatures          1  -none-             numeric    

STEP 5: Visualising an XGBoost Tree

We will use xgb.plot.tree() to plot the trees of the fitted model. Called with only the model argument, it plots all the trees at once; the trees argument selects individual trees by zero-based index.

# a plot with all the trees
xgb.plot.tree(model = model_xgboost)

# this is too cluttered to read, so we plot one tree at a time.
# The code below plots the first tree and shows its node IDs
xgb.plot.tree(model = model_xgboost, trees = 0, show_node_id = TRUE)
