How to optimise size (depth) of trees in XGBoost?


This recipe helps you optimise the size (depth) of trees in XGBoost.

def Snippet_192():
    print(format('How to optimise size (depth) of trees in XGBoost','*^82'))

    import warnings

    # load libraries
    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    from xgboost import XGBClassifier
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import StratifiedKFold
    import matplotlib
    from matplotlib import pyplot

    # load the wine dataset
    dataset = datasets.load_wine()
    X = dataset.data; y = dataset.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)

    # grid search
    model = XGBClassifier()
    max_depth = range(1, 11, 2)
    print(max_depth)
    param_grid = dict(max_depth=max_depth)
    kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=7)
    grid_search = GridSearchCV(model, param_grid, scoring="neg_log_loss", n_jobs=-1, cv=kfold, verbose=1)
    grid_result = grid_search.fit(X, y)

    # summarize results
    print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
    means = grid_result.cv_results_['mean_test_score']
    stds = grid_result.cv_results_['std_test_score']
    params = grid_result.cv_results_['params']

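    for mean, stdev, param in zip(means, stds, params):
        print("%f (%f) with: %r" % (mean, stdev, param))

    # plot the mean score (negated log loss, so closer to 0 is better) against max_depth
    pyplot.errorbar(max_depth, means, yerr=stds)
    pyplot.title("XGBoost max_depth vs Log Loss")
    pyplot.xlabel('max_depth')
    pyplot.ylabel('Negative Log Loss')
    pyplot.show()

Snippet_192()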

*****************How to optimise size (depth) of trees in XGBoost*****************
range(1, 11, 2)
Fitting 10 folds for each of 5 candidates, totalling 50 fits
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=-1)]: Done  43 out of  50 | elapsed:    2.1s remaining:    0.3s
[Parallel(n_jobs=-1)]: Done  50 out of  50 | elapsed:    2.2s finished
Best: -0.069259 using {'max_depth': 1}

-0.069259 (0.034427) with: {'max_depth': 1}
-0.083225 (0.059937) with: {'max_depth': 3}
-0.086606 (0.061344) with: {'max_depth': 5}
-0.086606 (0.061344) with: {'max_depth': 7}
-0.086606 (0.061344) with: {'max_depth': 9}
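
The scores printed above are negated log loss values (scoring="neg_log_loss"), so the value closest to zero wins. As a rough sketch of how these numbers could be read, the snippet below copies the means and standard deviations from the output, flips the sign to get ordinary log loss, and picks the shallowest depth whose mean is within one standard deviation of the best score. The helper names to_log_loss and simplest_within_one_std are hypothetical, and the one-standard-deviation rule is an assumption added here for illustration, not part of the recipe.

```python
# Scores copied from the grid-search output above: (max_depth, mean, std).
# scoring="neg_log_loss" reports the negated log loss, so closer to 0 is better.
results = [
    (1, -0.069259, 0.034427),
    (3, -0.083225, 0.059937),
    (5, -0.086606, 0.061344),
    (7, -0.086606, 0.061344),
    (9, -0.086606, 0.061344),
]


def to_log_loss(neg_score):
    """Convert a neg_log_loss score back to an ordinary (positive) log loss."""
    return -neg_score


def simplest_within_one_std(rows):
    """Hypothetical helper: return the smallest max_depth whose mean score
    is no worse than the best mean minus the best candidate's std."""
    best_depth, best_mean, best_std = max(rows, key=lambda r: r[1])
    threshold = best_mean - best_std
    candidates = [depth for depth, mean, _ in rows if mean >= threshold]
    return min(candidates)


for depth, mean, std in results:
    print(f"max_depth={depth}: log loss {to_log_loss(mean):.6f} (+/- {std:.6f})")

print("Simplest depth within one std of the best:", simplest_within_one_std(results))
```

On these scores the helper returns 1, which agrees with best_params_. All five candidates fall within one standard deviation of the best mean, so the differences between depths on this small dataset are probably within noise.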
