Often, when working on a dataset with a machine learning model, we don't know which set of hyperparameters will give the best result. Passing every set of hyperparameters through the model manually and checking each result is tedious and, for large search spaces, may not be practical.
To find the best set of hyperparameters, we can use Grid Search. Grid Search trains the model on every combination of the hyperparameter values we specify, one by one, and evaluates each result. Finally, it reports the combination that gave the best score.
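In plain Python, the exhaustive search described above can be sketched as follows. This is a simplified illustration, not how scikit-learn implements GridSearchCV: `grid_search` and `toy_score` are hypothetical names, and `toy_score` is a made-up function standing in for training and evaluating a real model.

```python
from itertools import product


def grid_search(param_grid, score_fn):
    """Try every combination in param_grid and return the best one.

    score_fn is a hypothetical callable that evaluates one combination
    and returns a score (higher is better).
    """
    names = sorted(param_grid)  # fixed key order so iteration is reproducible
    best_params, best_score = None, float("-inf")
    # product() yields every combination of values, one per parameter
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        score = score_fn(params)
        if score > best_score:  # strict ">" keeps the first best on ties
            best_params, best_score = params, score
    return best_params, best_score


# Toy score (made up): peaks at depth=6, learning_rate=0.05.
def toy_score(params):
    return -abs(params["depth"] - 6) - abs(params["learning_rate"] - 0.05)


best, score = grid_search(
    {"depth": [6, 8, 10], "learning_rate": [0.01, 0.05, 0.1]}, toy_score
)
print(best)  # {'depth': 6, 'learning_rate': 0.05}
```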
This Python source code does the following:
1. Installs CatBoost (pip install catboost)
2. Imports an sklearn dataset
3. Splits the dataset into training and test sets
4. Applies the CatBoost regressor
5. Tunes hyperparameters using GridSearchCV
This recipe is a short example of how to find optimal parameters for a CatBoost regressor using GridSearchCV.
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from catboost import CatBoostRegressor
Here we have imported various modules, such as datasets, train_test_split, GridSearchCV and CatBoostRegressor, from different libraries. We will see how each one is used later in the code snippets.
For now, just have a look at these imports.
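Here we use datasets to load the built-in iris dataset, storing the features in X and the target values in y. We then split the data into training and test sets, holding out 30% for testing. Note that the iris targets are the class labels 0, 1 and 2, which the regressor treats as continuous values.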
dataset = datasets.load_iris()
X = dataset.data
y = dataset.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.30)
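To show what test_size=0.30 does to the 150 iris rows, here is a simplified sketch of a shuffled split. `simple_train_test_split` is a hypothetical helper, not scikit-learn's implementation; it holds out ceil(0.30 × 150) = 45 rows and dummy data stands in for iris. Because the call above does not pass random_state, scikit-learn's split (and therefore the scores below) can change from run to run.

```python
import math
import random


def simple_train_test_split(X, y, test_size=0.30, seed=None):
    """Shuffle row indices and hold out ceil(test_size * n) rows for testing.

    A simplified, hypothetical stand-in for sklearn's train_test_split.
    """
    n = len(X)
    n_test = math.ceil(test_size * n)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded shuffle for reproducibility
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    X_train = [X[i] for i in train_idx]
    X_test = [X[i] for i in test_idx]
    y_train = [y[i] for i in train_idx]
    y_test = [y[i] for i in test_idx]
    return X_train, X_test, y_train, y_test


# 150 rows, like the iris dataset (dummy values, not real iris data).
X = [[float(i)] for i in range(150)]
y = [i % 3 for i in range(150)]
X_train, X_test, y_train, y_test = simple_train_test_split(X, y, seed=0)
print(len(X_train), len(X_test))  # 105 45
```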
Here we use CatBoostRegressor as the machine learning model to tune with GridSearchCV, so we create an object model_CBR.
model_CBR = CatBoostRegressor()
Now we define the hyperparameter values we want GridSearchCV to try in order to find the best combination. We store them in a dictionary called parameters with three keys: depth, learning_rate and iterations.
parameters = {'depth' : [6,8,10],
'learning_rate' : [0.01, 0.05, 0.1],
'iterations' : [30, 50, 100]
}
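As a quick check on cost: the grid above has 3 × 3 × 3 = 27 combinations, and with the cv = 2 used below each combination is fitted twice, plus one final refit of the best combination on the whole training set (GridSearchCV's default refit=True). A small sketch of that arithmetic:

```python
from math import prod

parameters = {'depth': [6, 8, 10],
              'learning_rate': [0.01, 0.05, 0.1],
              'iterations': [30, 50, 100]}
cv = 2

# Number of hyperparameter combinations in the grid
n_candidates = prod(len(values) for values in parameters.values())
# Each candidate is fitted once per fold, plus one refit of the best candidate
n_fits = n_candidates * cv + 1
print(n_candidates, n_fits)  # 27 55
```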
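Next, we pass the model and the parameter grid to GridSearchCV. Its important parameters here are: estimator, the model to tune (model_CBR); param_grid, the dictionary of hyperparameter values to try (parameters); cv, the number of cross-validation folds (2); and n_jobs, the number of jobs to run in parallel (-1 means use all processors). We then call fit on the training data.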
grid = GridSearchCV(estimator=model_CBR, param_grid = parameters, cv = 2, n_jobs=-1)
grid.fit(X_train, y_train)
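To make cv = 2 concrete: with an integer cv and a regressor, scikit-learn's GridSearchCV splits the training data with an unshuffled KFold, fits each candidate on one fold, scores it on the other, and averages the scores. The sketch below only reproduces the fold indices in plain Python; `kfold_indices` is a hypothetical helper, not part of scikit-learn, and it assumes the 105 training rows produced by the 70/30 split.

```python
def kfold_indices(n_samples, n_splits=2):
    """Yield (train_idx, val_idx) pairs of contiguous folds, no shuffling.

    Fold sizes follow the same rule as scikit-learn's default KFold:
    the first n_samples % n_splits folds get one extra sample.
    """
    fold_sizes = [n_samples // n_splits] * n_splits
    for i in range(n_samples % n_splits):
        fold_sizes[i] += 1
    start = 0
    for size in fold_sizes:
        stop = start + size
        val_idx = list(range(start, stop))  # this fold is held out
        train_idx = list(range(0, start)) + list(range(stop, n_samples))
        yield train_idx, val_idx
        start = stop


# 105 training rows (70% of the 150 iris rows) split into 2 folds.
folds = list(kfold_indices(105, n_splits=2))
print([len(val) for _, val in folds])  # [53, 52]
```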
Now we use print statements to display the results: the best estimator, its mean cross-validated score, and the best hyperparameter values.
print(" Results from Grid Search " )
print("\n The best estimator across ALL searched params:\n", grid.best_estimator_)
print("\n The best score across ALL searched params:\n", grid.best_score_)
print("\n The best parameters across ALL searched params:\n", grid.best_params_)
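best_score_ is the mean cross-validated score of the best candidate. With the default scoring=None, GridSearchCV uses the estimator's own score method, which CatBoost documents as R² (coefficient of determination) for CatBoostRegressor. The sketch below computes R² per fold and averages the results; the fold data is made up purely for illustration, and `r2_score` here is a hand-written helper, not the sklearn.metrics function of the same name.

```python
def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_y = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # residual sum of squares
    ss_tot = sum((t - mean_y) ** 2 for t in y_true)             # total sum of squares
    return 1 - ss_res / ss_tot


# Hypothetical (y_true, y_pred) pairs for one candidate on each of 2 folds.
fold_results = [
    ([0, 1, 2, 1], [0.1, 0.9, 1.8, 1.2]),
    ([2, 0, 1, 2], [1.9, 0.2, 1.0, 1.7]),
]
fold_scores = [r2_score(t, p) for t, p in fold_results]
# Averaging the fold scores gives the candidate's cross-validated score;
# best_score_ is this average for the best-scoring candidate.
mean_score = sum(fold_scores) / len(fold_scores)
print(fold_scores, mean_score)
```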
As an output we get:
0: learn: 0.7716436 total: 46.3ms remaining: 4.58s
1: learn: 0.7414652 total: 46.6ms remaining: 2.28s
2: learn: 0.7125578 total: 47.2ms remaining: 1.52s
...
97: learn: 0.0865905 total: 90ms remaining: 1.83ms
98: learn: 0.0856569 total: 90.4ms remaining: 912us
99: learn: 0.0847351 total: 90.8ms remaining: 0us
Results from Grid Search

The best estimator across ALL searched params:

The best score across ALL searched params:
0.9443438334960408

The best parameters across ALL searched params:
{'depth': 6, 'iterations': 100, 'learning_rate': 0.05}