How to optimise the size (depth) of trees in XGBoost?

This recipe helps you optimise the size (depth) of trees in XGBoost by grid-searching the max_depth parameter with stratified 10-fold cross-validation.

def Snippet_192():
    print(format('How to optimise size (depth) of trees in XGBoost','*^82'))

    import warnings

    # load libraries
    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    from xgboost import XGBClassifier
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import StratifiedKFold
    import matplotlib
    from matplotlib import pyplot

    # load the wine dataset
    dataset = datasets.load_wine()
    X = dataset.data; y = dataset.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)

    # grid search
    model = XGBClassifier()
    max_depth = range(1, 11, 2)
    print(max_depth)
    param_grid = dict(max_depth=max_depth)
    kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
    grid_search = GridSearchCV(model, param_grid, scoring="neg_log_loss", n_jobs=-1, cv=kfold, verbose=1)
    grid_result = grid_search.fit(X, y)

    # summarize results
    print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
    means = grid_result.cv_results_['mean_test_score']
    stds = grid_result.cv_results_['std_test_score']
    params = grid_result.cv_results_['params']

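    for mean, stdev, param in zip(means, stds, params):
        print("%f (%f) with: %r" % (mean, stdev, param))

    # plot the mean score per depth, with one standard deviation as error bars
    pyplot.errorbar(max_depth, means, yerr=stds)
    pyplot.title("XGBoost max_depth vs Log Loss")
    pyplot.xlabel('max_depth')
    # scoring is neg_log_loss, so the plotted values are negated
    pyplot.ylabel('Negative Log Loss')
    pyplot.show()

Snippet_192()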

*****************How to optimise size (depth) of trees in XGBoost*****************
range(1, 11, 2)
Fitting 10 folds for each of 5 candidates, totalling 50 fits
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done  43 out of  50 | elapsed:    2.1s remaining:    0.3s
[Parallel(n_jobs=-1)]: Done  50 out of  50 | elapsed:    2.2s finished
Best: -0.069259 using {'max_depth': 1}

-0.069259 (0.034427) with: {'max_depth': 1}
-0.083225 (0.059937) with: {'max_depth': 3}
-0.086606 (0.061344) with: {'max_depth': 5}
-0.086606 (0.061344) with: {'max_depth': 7}
-0.086606 (0.061344) with: {'max_depth': 9}
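
The scores above are negated log loss values, so the score closest to zero is the best one. The following is a minimal sketch of how to read the table. It uses only the mean and standard deviation figures copied from the printed output; the names `results`, `best_depth`, and `plateau` are illustrative and not part of the recipe.

```python
# Mean and standard deviation of neg_log_loss per max_depth,
# copied from the grid search output above.
results = {
    1: (-0.069259, 0.034427),
    3: (-0.083225, 0.059937),
    5: (-0.086606, 0.061344),
    7: (-0.086606, 0.061344),
    9: (-0.086606, 0.061344),
}

# scikit-learn negates log loss so that "higher is better" holds for
# every scorer; the best depth is the one with the largest mean score.
best_depth = max(results, key=lambda depth: results[depth][0])
best_mean, best_std = results[best_depth]

# Convert back to an ordinary (positive) log loss for reporting.
best_log_loss = -best_mean

print("best max_depth:", best_depth)
print("log loss: %.6f (+/- %.6f)" % (best_log_loss, best_std))

# Depths whose mean score equals the score at the deepest setting.
plateau = [d for d, (mean, _) in results.items() if mean == results[9][0]]
print("depths with identical scores:", plateau)
```

In this run, depths 5, 7, and 9 score identically. One plausible reading is that on this small dataset the trees stop splitting before they reach those depth limits, so raising max_depth further changes nothing. That is an interpretation of the output, not something the recipe itself states.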
