How to optimise the number of trees in XGBoost?

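This recipe shows how to tune the number of trees (n_estimators) of an XGBoost classifier with a cross-validated grid search, scoring each candidate by log loss.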

def Snippet_191():
    print(format('How to optimise number of trees in XGBoost','*^82'))

    import warnings
    warnings.filterwarnings("ignore")

    # load libraries
    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    from xgboost import XGBClassifier
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import StratifiedKFold
    from matplotlib import pyplot
    pyplot.style.use('ggplot')

    # load the wine dataset (178 samples, 13 features, 3 classes)
    dataset = datasets.load_wine()
    X, y = dataset.data, dataset.target
    # hold-out split (not used below: the grid search cross-validates on all of X, y)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)

    # grid search over the number of trees: 50, 100, ..., 350
    model = XGBClassifier()
    n_estimators = range(50, 400, 50)
    param_grid = dict(n_estimators=n_estimators)
    kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
    grid_search = GridSearchCV(model, param_grid, scoring="neg_log_loss", n_jobs=-1, cv=kfold)
    grid_result = grid_search.fit(X, y)

    # summarize results (scores are negated log loss, so values closer to 0 are better)
    print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
    means = grid_result.cv_results_['mean_test_score']
    stds = grid_result.cv_results_['std_test_score']
    params = grid_result.cv_results_['params']
    for mean, stdev, param in zip(means, stds, params):
        print("%f (%f) with: %r" % (mean, stdev, param))

    # plot log loss (negating the neg_log_loss scores) against the number of trees
    pyplot.errorbar(list(n_estimators), -means, yerr=stds)
    pyplot.title("XGBoost n_estimators vs Log Loss")
    pyplot.xlabel('n_estimators')
    pyplot.ylabel('Log Loss')
    pyplot.show()

Snippet_191()

********************How to optimise number of trees in XGBoost********************

Best: -0.077742 using {'n_estimators': 250}

-0.108811 (0.060179) with: {'n_estimators': 50}
-0.083225 (0.059937) with: {'n_estimators': 100}
-0.079464 (0.058413) with: {'n_estimators': 150}
-0.077744 (0.057482) with: {'n_estimators': 200}
-0.077742 (0.057480) with: {'n_estimators': 250}
-0.077754 (0.057472) with: {'n_estimators': 300}
-0.077754 (0.057472) with: {'n_estimators': 350}
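
The mean score barely moves past about 200 trees (-0.077744 at 200 versus -0.077742 at 250), so the extra trees buy almost nothing. A minimal sketch, assuming you would rather keep the smallest n_estimators whose mean score lies within a chosen tolerance of the best one: the helper name smallest_within_tolerance and the 0.001 tolerance are hypothetical choices, and the scores are copied from the printed output above.

```python
# Mean CV scores (neg_log_loss, higher is better) per n_estimators,
# copied from the grid search output.
results = {
    50: -0.108811,
    100: -0.083225,
    150: -0.079464,
    200: -0.077744,
    250: -0.077742,
    300: -0.077754,
    350: -0.077754,
}


def smallest_within_tolerance(scores, tol):
    """Return the smallest n_estimators whose score is within `tol` of the best.

    Hypothetical helper: `scores` maps n_estimators -> mean neg_log_loss,
    and `tol` is an absolute tolerance picked by the user.
    """
    best = max(scores.values())
    candidates = [n for n, s in scores.items() if best - s <= tol]
    return min(candidates)


best_n = max(results, key=results.get)
print("best:", best_n)  # prints "best: 250"
print("smallest within 0.001:", smallest_within_tolerance(results, 0.001))  # prints 200
```

With a tolerance of 0.001, 200 trees give essentially the same cross-validated log loss as 250 with a smaller model; the tolerance itself is a judgement call, and the standard deviations in the output (around 0.057) suggest the differences between 200 and 350 trees are well within fold-to-fold noise.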
