How to visualise XGBoost tree in Python?

This recipe helps you visualise an XGBoost tree in Python.


Recipe Objective

Have you ever tried to plot an XGBoost tree in Python and visualise it as a tree? In this recipe we will train an XGBoost classifier, predict the output and plot the trees.

So this is the recipe on how we can visualise an XGBoost tree in Python.

Step 1 - Import the library

from sklearn import datasets
from sklearn import metrics
from xgboost import XGBClassifier, plot_tree
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.style.use('ggplot')

We have imported all the modules that will be needed, such as metrics, datasets, XGBClassifier and plot_tree. Note that plot_tree also needs the graphviz package to be installed. We will see the use of each module step by step below.

Step 2 - Setting up the Data for Classifier

We have loaded the inbuilt breast_cancer dataset from the datasets module and stored the data in X and the target in y. We have also used train_test_split to split the dataset into two parts, with 30% of the data in the test set and the rest in the training set.

dataset = datasets.load_breast_cancer()
X = dataset.data
y = dataset.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30)
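Because the split above is not seeded, every run draws a different test set. Here is a minimal sketch of the same 30% split made reproducible; random_state=42 is an arbitrary value and stratify=y is an optional addition, neither is part of the original recipe.

```python
# Same data and 30% test split as the recipe, but seeded and stratified.
# random_state=42 is an arbitrary choice; stratify=y is an optional extra.
from sklearn import datasets
from sklearn.model_selection import train_test_split

dataset = datasets.load_breast_cancer()
X, y = dataset.data, dataset.target

# stratify=y keeps the malignant/benign ratio the same in train and test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.30, random_state=42, stratify=y
)

print(X_train.shape, X_test.shape)     # (398, 30) (171, 30)
print(dataset.target_names.tolist())   # ['malignant', 'benign']
```

The 171 test rows match the support total in the classification report printed later.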

Step 3 - Training XGBClassifier and Predicting the output

We have created an object for the model and fitted it on the training data. Then we have tested the model by predicting the output for the test data.

model_XGB = XGBClassifier()
model_XGB.fit(X_train, y_train)
print(model_XGB)
expected_y = y_test
predicted_y = model_XGB.predict(X_test)

Step 4 - Calculating the Scores

Now we calculate other scores for the model using classification_report and confusion_matrix, passing the expected and predicted target values of the test set.

print(metrics.classification_report(expected_y, predicted_y, target_names=dataset.target_names))
print(metrics.confusion_matrix(expected_y, predicted_y))
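To read the confusion matrix correctly it helps to know its layout: in scikit-learn, row i is the true class i and column j is the predicted class j. The toy labels below are made up purely for illustration and use the same encoding as load_breast_cancer (0 = malignant, 1 = benign).

```python
# Toy illustration of scikit-learn's confusion matrix layout (labels are made up).
from sklearn import metrics

# 0 = malignant, 1 = benign, as in load_breast_cancer
expected_y = [0, 0, 0, 1, 1]
predicted_y = [0, 1, 0, 1, 1]

cm = metrics.confusion_matrix(expected_y, predicted_y)
print(cm)
# Rows are true classes, columns are predicted classes:
# [[2 1]   -> two malignant predicted malignant, one predicted benign
#  [0 2]]  -> both benign predicted benign
```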

Step 5 - Plotting the tree

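We plot the trees of the XGBClassifier with plot_tree. The num_trees parameter selects which tree to draw (0 is the first tree, so num_trees=4 draws the fifth), and rankdir='LR' lays the tree out from left to right. When running as a script rather than in a notebook, call plt.show() to display the figures.

plot_tree(model_XGB)
plot_tree(model_XGB, num_trees=4)
plot_tree(model_XGB, num_trees=0, rankdir='LR')
plt.show()

The printed output is shown below. The exact parameter listing depends on your xgboost version, and the scores can vary from run to run because the train/test split is not seeded: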

XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.1,
       max_delta_step=0, max_depth=3, min_child_weight=1, missing=None,
       n_estimators=100, n_jobs=1, nthread=None,
       objective='binary:logistic', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,
       subsample=1, verbosity=1)


              precision    recall  f1-score   support

   malignant       0.99      0.94      0.96        70
      benign       0.96      0.99      0.98       101

   micro avg       0.97      0.97      0.97       171
   macro avg       0.97      0.97      0.97       171
weighted avg       0.97      0.97      0.97       171

[[ 66   4]
 [  1 100]]
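As a worked check, the per-class scores in the report follow directly from the printed confusion matrix (rows are true classes, columns are predicted classes, 0 = malignant, 1 = benign). For single-label classification the "micro avg" row equals plain accuracy.

```python
# Recompute the classification report from the printed confusion matrix.
# Rows = true class, columns = predicted class; 0 = malignant, 1 = benign.
cm = [[66, 4],
      [1, 100]]

def scores(cm, k):
    """Precision, recall and F1 for class index k of a 2x2 confusion matrix."""
    tp = cm[k][k]
    fp = cm[1 - k][k]   # other class predicted as k
    fn = cm[k][1 - k]   # class k predicted as the other class
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * tp / (2 * tp + fp + fn)
    return round(precision, 2), round(recall, 2), round(f1, 2)

print("malignant", scores(cm, 0))   # (0.99, 0.94, 0.96)
print("benign", scores(cm, 1))      # (0.96, 0.99, 0.98)

accuracy = (cm[0][0] + cm[1][1]) / sum(map(sum, cm))
print(round(accuracy, 2))           # 0.97, the "micro avg" row
```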
