How to visualise XGBoost feature importance in Python?

This recipe helps you visualise XGBoost feature importance in Python.


Recipe Objective

Quite often we need to find out which features matter most for training a model. This becomes especially important when there is a large number of features and training on all of them is computationally expensive. XGBoost can tell us which features are important through its feature importance scores.

So this is the recipe on how we can visualise XGBoost feature importance in Python.

Step 1 - Import the libraries

from sklearn import datasets
from sklearn import metrics
from sklearn.model_selection import train_test_split
from xgboost import XGBClassifier, plot_importance
import matplotlib.pyplot as plt

We have imported various modules from different libraries: datasets, metrics, train_test_split, XGBClassifier, plot_importance and plt.

Step 2 - Setting up the Data

We use the built-in breast cancer dataset to train the model, and train_test_split to split the data into two parts: train and test.

dataset = datasets.load_breast_cancer()
X = dataset.data; y = dataset.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)
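Because train_test_split shuffles randomly, every run produces a different split and slightly different scores. A minimal sketch of a reproducible variant follows; random_state=0 and stratify=y are assumptions added for illustration and are not in the original recipe. With 569 samples and test_size=0.25, scikit-learn rounds the test size up to 143 rows, which matches the support of 143 in the report further down.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

dataset = datasets.load_breast_cancer()
X = dataset.data      # shape (569, 30): 30 numeric features per sample
y = dataset.target    # 0 = malignant, 1 = benign (see dataset.target_names)

# random_state fixes the shuffle so results are reproducible between runs;
# stratify=y keeps the malignant/benign ratio the same in train and test.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y
)

print(X_train.shape, X_test.shape)
print("benign share (train, test):", round(float(y_train.mean()), 3), round(float(y_test.mean()), 3))
```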

Step 3 - Training the Model

We create an XGBClassifier and fit it on the training data. After that we make two objects: one holding the true values of y_test and another holding the values predicted by the model.

model = XGBClassifier()
model.fit(X_train, y_train)
print(model)
expected_y = y_test
predicted_y = model.predict(X_test)

Step 4 - Printing the results and plotting the graph

Finally, we print the results: the classification_report and the confusion_matrix. We also visualise the importance of the features twice: once as a horizontal bar graph drawn with plt.barh from model.feature_importances_, and once with XGBoost's own plot_importance helper.

print(); print('XGBClassifier: ')
print(); print(metrics.classification_report(expected_y, predicted_y, target_names=dataset.target_names))
print(); print(metrics.confusion_matrix(expected_y, predicted_y))

plt.barh(range(len(model.feature_importances_)), model.feature_importances_)
plt.show()

plot_importance(model)
plt.show()

Output of this snippet is given below:

XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
       colsample_bynode=1, colsample_bytree=1, gamma=0, learning_rate=0.1,
       max_delta_step=0, max_depth=3, min_child_weight=1, missing=None,
       n_estimators=100, n_jobs=1, nthread=None,
       objective='binary:logistic', random_state=0, reg_alpha=0,
       reg_lambda=1, scale_pos_weight=1, seed=None, silent=None,
       subsample=1, verbosity=1)


              precision    recall  f1-score   support

   malignant       0.98      0.96      0.97        53
      benign       0.98      0.99      0.98        90

   micro avg       0.98      0.98      0.98       143
   macro avg       0.98      0.98      0.98       143
weighted avg       0.98      0.98      0.98       143

[[51  2]
 [ 1 89]]
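The per-class numbers in the report can be recomputed from the confusion matrix printed just above. In scikit-learn's confusion_matrix, rows are the true classes and columns the predicted classes, in the order of dataset.target_names (malignant, benign). The sketch below hard-codes that printed matrix; the helper name per_class_metrics is hypothetical.

```python
# Confusion matrix from the output above: rows = true class, columns = predicted class.
cm = [[51, 2],
      [1, 89]]


def per_class_metrics(cm, i):
    """Hypothetical helper: precision, recall, F1 (rounded to 2 d.p.) and support for class i."""
    tp = cm[i][i]
    predicted_i = sum(row[i] for row in cm)  # column sum: everything predicted as class i
    actual_i = sum(cm[i])                    # row sum: everything truly in class i (support)
    precision = tp / predicted_i
    recall = tp / actual_i
    f1 = 2 * precision * recall / (precision + recall)
    return round(precision, 2), round(recall, 2), round(f1, 2), actual_i


for i, name in enumerate(["malignant", "benign"]):
    print(name, per_class_metrics(cm, i))
# malignant (0.98, 0.96, 0.97, 53)
# benign (0.98, 0.99, 0.98, 90)

accuracy = (cm[0][0] + cm[1][1]) / sum(sum(row) for row in cm)
print("accuracy:", round(accuracy, 2))  # 140 / 143 -> 0.98
```

For example, malignant precision is 51 / (51 + 1) and recall is 51 / (51 + 2), which round to the 0.98 and 0.96 in the report.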
