How to visualise an XGBoost tree in R?

This recipe helps you visualise an XGBoost tree in R.

Recipe Objective

Classification and regression are supervised learning problems that can be solved using algorithms such as linear regression, logistic regression and decision trees. On their own, however, these models are often not competitive in terms of prediction accuracy. Ensemble techniques, on the other hand, create multiple models and combine them into one to produce more effective results.

Bagging, boosting and random forests are different types of ensemble techniques. Boosting is a sequential ensemble technique in which the model is improved using the information from previously grown weaker models. This process continues for multiple iterations until a final model is built that predicts a more accurate outcome.

Three widely used boosting techniques are:

  1. AdaBoost
  2. Gradient Boosting
  3. XGBoost

Recently, researchers and enthusiasts have started using ensemble techniques like XGBoost to win data science competitions and hackathons. On structured data it often outperforms algorithms such as Random Forest and Gradient Boosting in terms of both speed and accuracy.

XGBoost uses an ensemble model based on decision trees. A simple decision tree is considered to be a weak learner. The algorithm builds decision trees sequentially, where each tree corrects the errors of the previous ones until a stopping condition is met.
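The idea of each tree correcting the errors of the previous ones can be sketched in a few lines of base R. This is a simplified illustration under stated assumptions, not XGBoost's actual algorithm: it uses depth-1 trees (stumps) fitted to plain residuals on made-up data, whereas XGBoost also uses regularisation and second-order gradient information. The helper names fit_stump and predict_stump, the toy data and the learning rate are all hypothetical.

```r
# Toy data (an assumption for illustration): a simple quadratic
x <- seq(-3, 3, length.out = 50)
y <- x^2

# Fit a depth-1 tree (stump) to the target r: try every split point on x
# and keep the one with the smallest sum of squared errors
fit_stump <- function(x, r) {
  candidates <- head(sort(unique(x)), -1)  # drop the maximum so the right side is never empty
  best <- NULL
  best_sse <- Inf
  for (s in candidates) {
    left <- mean(r[x <= s])
    right <- mean(r[x > s])
    sse <- sum((r - ifelse(x <= s, left, right))^2)
    if (sse < best_sse) {
      best_sse <- sse
      best <- list(split = s, left = left, right = right)
    }
  }
  best
}

predict_stump <- function(stump, x) {
  ifelse(x <= stump$split, stump$left, stump$right)
}

eta <- 0.3                          # learning rate (hypothetical value)
pred <- rep(mean(y), length(y))     # start from a constant prediction
rmse_by_round <- numeric(20)

for (m in 1:20) {
  residual <- y - pred                         # errors of the current ensemble
  stump <- fit_stump(x, residual)              # the new weak learner targets those errors
  pred <- pred + eta * predict_stump(stump, x) # add a shrunken correction
  rmse_by_round[m] <- sqrt(mean((y - pred)^2))
}

print(round(rmse_by_round[c(1, 10, 20)], 3))
```

The training RMSE shrinks from round to round, which mirrors the train-rmse column that xgb.train() prints later in this recipe.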

In this recipe, we will discuss how to build and visualise an XGBoost tree.

STEP 1: Importing Necessary Libraries

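library(caret)      # for general data preparation and model fitting
library(rpart.plot)
library(tidyverse)
library(xgboost)    # for xgb.DMatrix(), xgb.train(), xgboost() and xgb.plot.tree()
# Note: xgb.plot.tree() also requires the DiagrammeR package to be installed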

STEP 2: Read a csv file and explore the data

The attached dataset contains data on 159 different bags associated with ABC industries.

The bags have certain attributes which are described below: ​

  1. Height – The height of the bag
  2. Width – The width of the bag
  3. Length – The length of the bag
  4. Weight – The weight the bag can carry
  5. Weight1 – Weight the bag can carry after expansion

The company now wants to predict the cost they should set for a new variant of these kinds of bags. ​

data <- read.csv("R_356_Data_1.csv")
glimpse(data)

Rows: 159
Columns: 6
$ Cost     242, 290, 340, 363, 430, 450, 500, 390, 450, 500, 475, 500,...
$ Weight   23.2, 24.0, 23.9, 26.3, 26.5, 26.8, 26.8, 27.6, 27.6, 28.5,...
$ Weight1  25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7,...
$ Length   30.0, 31.2, 31.1, 33.5, 34.0, 34.7, 34.5, 35.0, 35.1, 36.2,...
$ Height   11.5200, 12.4800, 12.3778, 12.7300, 12.4440, 13.6024, 14.17...
$ Width    4.0200, 4.3056, 4.6961, 4.4555, 5.1340, 4.9274, 5.2785, 4.6...

summary(data) # returns the statistical summary of the data columns

Cost            Weight         Weight1          Length     
 Min.   :   0.0   Min.   : 7.50   Min.   : 8.40   Min.   : 8.80  
 1st Qu.: 120.0   1st Qu.:19.05   1st Qu.:21.00   1st Qu.:23.15  
 Median : 273.0   Median :25.20   Median :27.30   Median :29.40  
 Mean   : 398.3   Mean   :26.25   Mean   :28.42   Mean   :31.23  
 3rd Qu.: 650.0   3rd Qu.:32.70   3rd Qu.:35.50   3rd Qu.:39.65  
 Max.   :1650.0   Max.   :59.00   Max.   :63.40   Max.   :68.00  
     Height           Width      
 Min.   : 1.728   Min.   :1.048  
 1st Qu.: 5.945   1st Qu.:3.386  
 Median : 7.786   Median :4.248  
 Mean   : 8.971   Mean   :4.417  
 3rd Qu.:12.366   3rd Qu.:5.585  
 Max.   :18.957   Max.   :8.142   

dim(data)

159 6

STEP 3: Train Test Split

# Use createDataPartition() from the caret package to split the original dataset
# into a training set (80%) and a testing set (20%)
parts = createDataPartition(data$Cost, p = .8, list = FALSE)
train = data[parts, ]
test = data[-parts, ]

# define predictor and response variables in training set
train_x = data.matrix(train[, -1])
train_y = train[, 1]

# define predictor and response variables in testing set
test_x = data.matrix(test[, -1])
test_y = test[, 1]

# define final training and testing sets
xgb_train = xgb.DMatrix(data = train_x, label = train_y)
xgb_test = xgb.DMatrix(data = test_x, label = test_y)
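Because the split is drawn at random, the rows in each set (and therefore the RMSE values reported later) change from run to run unless a seed is fixed first. The sketch below shows the idea with base R on made-up data. The seed value, the toy data frame and the use of sample() are assumptions for illustration: sample() draws a simple random split, while createDataPartition() additionally balances the split across the range of the outcome.

```r
set.seed(42)  # arbitrary seed, chosen only for illustration

# Toy stand-in for the bags data (hypothetical values)
toy <- data.frame(Cost = runif(100, 0, 1650),
                  Weight = runif(100, 7.5, 59))

# Draw 80% of the row indices for training; the rest form the test set
idx <- sample(seq_len(nrow(toy)), size = floor(0.8 * nrow(toy)))
toy_train <- toy[idx, ]
toy_test <- toy[-idx, ]

print(c(train = nrow(toy_train), test = nrow(toy_test)))
```

Calling set.seed() with the same value before createDataPartition() should likewise make the recipe's own split repeatable.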

STEP 4: Create an XGBoost model

Now, we will fit and train our model using the xgb.train() function, which prints the training and testing root mean squared error (RMSE) for each round.

Here, the max.depth parameter determines how deep each tree may grow; we choose a value of 3.

# define a watchlist
watchlist = list(train = xgb_train, test = xgb_test)

# fit the XGBoost model and display training and testing RMSE at each iteration
model = xgb.train(data = xgb_train, max.depth = 3, watchlist = watchlist, nrounds = 100)

[1]	train-rmse:374.441406	test-rmse:481.788391 
[2]	train-rmse:274.574158	test-rmse:377.512909 
[3]	train-rmse:204.863098	test-rmse:306.634033 
[4]	train-rmse:155.649658	test-rmse:251.804932 
[5]	train-rmse:119.886559	test-rmse:206.584793 
[6]	train-rmse:94.443649	test-rmse:170.362732 
[7]	train-rmse:76.098549	test-rmse:157.283279 
[8]	train-rmse:63.038189	test-rmse:148.384521 
[9]	train-rmse:53.171177	test-rmse:142.591125 
[10]	train-rmse:46.219536	test-rmse:126.492058 
[11]	train-rmse:41.068180	test-rmse:112.861725 
[12]	train-rmse:37.273392	test-rmse:101.792809 
[13]	train-rmse:33.991714	test-rmse:99.646431 
[14]	train-rmse:31.665110	test-rmse:91.611916 
[15]	train-rmse:29.955919	test-rmse:84.864738 
[16]	train-rmse:28.531353	test-rmse:79.398239 
[17]	train-rmse:27.040276	test-rmse:74.698051 
[18]	train-rmse:26.302597	test-rmse:70.936241 
[19]	train-rmse:25.201057	test-rmse:67.750641 
[20]	train-rmse:24.487757	test-rmse:65.076195 
[21]	train-rmse:23.867445	test-rmse:65.166847 
[22]	train-rmse:22.876081	test-rmse:63.112698 
[23]	train-rmse:22.164562	test-rmse:61.523403 
[24]	train-rmse:21.816034	test-rmse:61.467430 
[25]	train-rmse:21.125587	test-rmse:61.402748 
[26]	train-rmse:20.957186	test-rmse:60.343128 
[27]	train-rmse:20.365843	test-rmse:60.348598 
[28]	train-rmse:20.168547	test-rmse:59.282814 
[29]	train-rmse:18.995090	test-rmse:58.969128 
[30]	train-rmse:18.819603	test-rmse:59.020538 
[31]	train-rmse:18.699118	test-rmse:58.379250 
[32]	train-rmse:17.504850	test-rmse:57.781509 
[33]	train-rmse:17.387026	test-rmse:57.645771 
[34]	train-rmse:17.037064	test-rmse:57.125183 
[35]	train-rmse:16.668007	test-rmse:56.830990 
[36]	train-rmse:16.044168	test-rmse:56.780052 
[37]	train-rmse:15.536475	test-rmse:56.567234 
[38]	train-rmse:15.433763	test-rmse:56.546337 
[39]	train-rmse:15.098138	test-rmse:56.664021 
[40]	train-rmse:14.819264	test-rmse:56.322807 
[41]	train-rmse:14.625785	test-rmse:56.316051 
[42]	train-rmse:14.350323	test-rmse:56.248844 
[43]	train-rmse:14.131385	test-rmse:56.189671 
[44]	train-rmse:13.516161	test-rmse:56.011814 
[45]	train-rmse:13.048274	test-rmse:56.140182 
[46]	train-rmse:12.758994	test-rmse:55.925411 
[47]	train-rmse:12.444994	test-rmse:56.098057 
[48]	train-rmse:12.082834	test-rmse:56.063778 
[49]	train-rmse:11.696443	test-rmse:56.002361 
[50]	train-rmse:11.560493	test-rmse:56.020744 
[51]	train-rmse:11.102805	test-rmse:56.114948 
[52]	train-rmse:10.627243	test-rmse:56.106552 
[53]	train-rmse:10.547875	test-rmse:56.181263 
[54]	train-rmse:10.363978	test-rmse:55.970352 
[55]	train-rmse:10.133872	test-rmse:56.034210 
[56]	train-rmse:9.734212	test-rmse:56.160725 
[57]	train-rmse:9.508077	test-rmse:56.177059 
[58]	train-rmse:9.202065	test-rmse:56.142998 
[59]	train-rmse:8.973363	test-rmse:56.266232 
[60]	train-rmse:8.868542	test-rmse:55.894310 
[61]	train-rmse:8.712481	test-rmse:55.797413 
[62]	train-rmse:8.450444	test-rmse:55.796597 
[63]	train-rmse:8.261618	test-rmse:55.789951 
[64]	train-rmse:8.081842	test-rmse:55.639320 
[65]	train-rmse:7.938920	test-rmse:55.682808 
[66]	train-rmse:7.682938	test-rmse:55.756508 
[67]	train-rmse:7.553942	test-rmse:55.836765 
[68]	train-rmse:7.432102	test-rmse:55.685822 
[69]	train-rmse:7.294747	test-rmse:55.697899 
[70]	train-rmse:7.103888	test-rmse:55.749569 
[71]	train-rmse:6.905044	test-rmse:55.763145 
[72]	train-rmse:6.753871	test-rmse:55.844006 
[73]	train-rmse:6.690207	test-rmse:55.758812 
[74]	train-rmse:6.488435	test-rmse:55.740368 
[75]	train-rmse:6.299417	test-rmse:55.737957 
[76]	train-rmse:6.090727	test-rmse:55.710434 
[77]	train-rmse:5.966695	test-rmse:55.743229 
[78]	train-rmse:5.857632	test-rmse:55.720600 
[79]	train-rmse:5.828579	test-rmse:55.569942 
[80]	train-rmse:5.622557	test-rmse:55.612438 
[81]	train-rmse:5.563855	test-rmse:55.506058 
[82]	train-rmse:5.381149	test-rmse:55.447449 
[83]	train-rmse:5.306352	test-rmse:55.385094 
[84]	train-rmse:5.159195	test-rmse:55.371307 
[85]	train-rmse:5.009599	test-rmse:55.202850 
[86]	train-rmse:4.988478	test-rmse:55.135273 
[87]	train-rmse:4.858966	test-rmse:55.196877 
[88]	train-rmse:4.761164	test-rmse:55.197235 
[89]	train-rmse:4.649740	test-rmse:55.199398 
[90]	train-rmse:4.545322	test-rmse:55.266251 
[91]	train-rmse:4.471013	test-rmse:55.323376 
[92]	train-rmse:4.442612	test-rmse:55.336811 
[93]	train-rmse:4.399715	test-rmse:55.298866 
[94]	train-rmse:4.289005	test-rmse:55.273613 
[95]	train-rmse:4.196774	test-rmse:55.273048 
[96]	train-rmse:4.097926	test-rmse:55.231594 
[97]	train-rmse:3.942547	test-rmse:55.206097 
[98]	train-rmse:3.923210	test-rmse:55.145107 
[99]	train-rmse:3.835154	test-rmse:55.166672 
[100]	train-rmse:3.761758	test-rmse:55.160030 
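The number of rounds for the final model can be read off this log by looking for the lowest test RMSE. A minimal sketch, using a few rows copied from the log above rather than the fitted model; the name eval_log is hypothetical. The commented line at the end assumes an xgboost 1.x release, where the log is stored in model$evaluation_log with columns iter and test_rmse.

```r
# Rows copied from the log above (rounds 84 to 88 and 98 to 100)
eval_log <- data.frame(
  iter = c(84, 85, 86, 87, 88, 98, 99, 100),
  test_rmse = c(55.371307, 55.202850, 55.135273, 55.196877,
                55.197235, 55.145107, 55.166672, 55.160030)
)

# Round with the lowest test RMSE
best_iter <- eval_log$iter[which.min(eval_log$test_rmse)]
print(best_iter)  # 86

# With the fitted model, the full log would be used instead (assumption: xgboost 1.x):
# best_iter <- model$evaluation_log$iter[which.min(model$evaluation_log$test_rmse)]
```

Across the whole log shown above, round 86 has the lowest test RMSE (55.135273), while the training RMSE keeps falling until round 100, which is a sign of overfitting beyond that point.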

# define the final model; 86 rounds gave the lowest test RMSE in the log above
model_xgboost = xgboost(data = xgb_train, max.depth = 3, nrounds = 86, verbose = 0)
summary(model_xgboost)

               Length Class              Mode       
handle             1  xgb.Booster.handle externalptr
raw            91316  -none-             raw        
niter              1  -none-             numeric    
evaluation_log     2  data.table         list       
call              14  -none-             call       
params             2  -none-             list       
callbacks          1  -none-             list       
feature_names      5  -none-             character  
nfeatures          1  -none-             numeric    

STEP 5: Visualising an xgboost Tree

We will use xgb.plot.tree(model = ) to plot the trees of the fitted model.

# a plot with all the trees
xgb.plot.tree(model = model_xgboost)

# Plotting every tree at once is hard to read, so we stick to one tree at a time.
# The code below plots the first tree (index 0) and shows its node IDs.
xgb.plot.tree(model = model_xgboost, trees = 0, show_node_id = TRUE)
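When the plot is hard to read, the same tree structure can be inspected as a table with xgb.model.dt.tree() from the xgboost package. The sketch below is self-contained and therefore trains a tiny model on the built-in mtcars data, which is an assumption for illustration only; in this recipe the call would be xgb.model.dt.tree(model = model_xgboost). Column names other than Tree, Node and Feature may differ between xgboost versions.

```r
library(xgboost)

# Toy regression data from the built-in mtcars dataset (illustration only)
x <- data.matrix(mtcars[, c("wt", "hp", "disp")])
y <- mtcars$mpg
dtrain <- xgb.DMatrix(data = x, label = y)

# Train a small model: 2 trees, each at most 2 levels deep
bst <- xgb.train(
  params = list(max_depth = 2, objective = "reg:squarederror", nthread = 1),
  data = dtrain,
  nrounds = 2
)

# One row per node; the Tree column is a zero-based tree index
tree_table <- xgb.model.dt.tree(model = bst)

# Nodes of the first tree, matching what trees = 0 selects in xgb.plot.tree()
print(tree_table[tree_table$Tree == 0, ])
```

Each row describes either a split (feature and threshold) or a leaf, so the table carries the same information as the diagram in a form that can be filtered and sorted.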
