How to visualise XGBoost tree in Python?

This recipe helps you visualise XGBoost tree in Python

Recipe Objective

Have you ever tried to plot an XGBoost model in Python and visualise it in the form of a tree? In this recipe we will train an XGBoost classifier, predict the output and plot the individual trees.

So this is the recipe on how to visualise an XGBoost tree in Python.


Step 1 - Import the library

from sklearn import datasets
from sklearn import metrics
from xgboost import XGBClassifier, plot_tree
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.style.use('ggplot')

We have imported all the modules we will need, such as metrics, datasets, XGBClassifier and plot_tree. We will see the use of each module step by step below.

Step 2 - Setting up the Data for Classifier

We have imported the inbuilt breast_cancer dataset from the module datasets and stored the data in X and the target in y. We have also used train_test_split to split the dataset into two parts, such that 30% of the data is in the test set and the rest in the train set.

dataset = datasets.load_breast_cancer()
X = dataset.data; y = dataset.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30)
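Note that train_test_split shuffles the data, so each run can produce a slightly different split and slightly different scores. As a minimal optional sketch (the random_state value here is an arbitrary choice, not part of the original recipe), you can fix the seed and check the resulting shapes:

# fixing random_state makes the 70/30 split reproducible across runs
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30, random_state=42)
print(X_train.shape, X_test.shape)   # e.g. (398, 30) (171, 30) for the 569-sample breast cancer data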

Step 3 - Training XGBClassifier and Predicting the output

We have created an object for the model and fitted it on the training data. Then we have used the test data to evaluate the model by predicting its output for the test set.

model_XGB = XGBClassifier()
model_XGB.fit(X_train, y_train)
print(model_XGB)
expected_y = y_test
predicted_y = model_XGB.predict(X_test)
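If you also want a single headline number before looking at the per-class scores in the next step, a quick sketch (using sklearn's accuracy_score, which is not shown in the original recipe) is:

# fraction of test samples the classifier got right
print(metrics.accuracy_score(expected_y, predicted_y))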

Step 4 - Calculating the Scores

Now we are calculating other scores for the model using classification_report and confusion_matrix, passing the expected and predicted values of the target for the test set.

print(metrics.classification_report(expected_y, predicted_y, target_names=dataset.target_names))
print(metrics.confusion_matrix(expected_y, predicted_y))
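The raw confusion matrix printed below is a 2x2 array of counts, with rows for the true classes and columns for the predicted classes. As a small optional sketch (pandas is an extra dependency not used in the original recipe), you can label it for readability:

import pandas as pd
# wrap the counts in a DataFrame so the malignant/benign labels are visible on both axes
cm = metrics.confusion_matrix(expected_y, predicted_y)
print(pd.DataFrame(cm, index=dataset.target_names, columns=dataset.target_names))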


Step 5 - Plotting the tree

We are plotting the trees for the XGBClassifier by passing the required parameters to plot_tree: by default the first tree is drawn, num_trees selects a specific tree from the ensemble, and rankdir='LR' lays the tree out left to right instead of top to bottom.

plot_tree(model_XGB); plt.show()
plot_tree(model_XGB, num_trees=4); plt.show()
plot_tree(model_XGB, num_trees=0, rankdir='LR'); plt.show()

So the final output comes as:

XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.1,
       max_delta_step=0, max_depth=3, min_child_weight=1, missing=None,
       n_estimators=100, n_jobs=1, nthread=None,
       objective='binary:logistic', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,
       subsample=1, verbosity=1)

XGBClassifier: 

              precision    recall  f1-score   support

   malignant       0.99      0.94      0.96        70
      benign       0.96      0.99      0.98       101

   micro avg       0.97      0.97      0.97       171
   macro avg       0.97      0.97      0.97       171
weighted avg       0.97      0.97      0.97       171


[[ 66   4]
 [  1 100]]
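A couple of practical notes on Step 5: plot_tree relies on the graphviz package being available, and the default figure can be too small to read for deeper trees. A minimal sketch for rendering one tree at a larger size and saving it to disk (the figure size and filename here are arbitrary choices, not part of the original recipe):

fig, ax = plt.subplots(figsize=(30, 15))                  # large canvas so the node text stays legible
plot_tree(model_XGB, num_trees=0, rankdir='LR', ax=ax)    # draw the first tree left-to-right on our axes
fig.savefig('xgb_tree.png', dpi=200)                      # write the plot to a file instead of only showing it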

