Recipe: How to plot Validation Curve in Python?

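This recipe shows how to plot a validation curve in Python: the training and cross-validation accuracy of a random forest on the digits dataset as the number of trees varies.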
In [2]:
## How to plot Validation Curve in Python
def Snippet_141():
    print()
    print(format('How to plot Validation Curve in Python','*^82'))

    import warnings
    warnings.filterwarnings("ignore")

    # load libraries
    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.datasets import load_digits
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import validation_curve

    # Load data
    digits = load_digits()

    # Create feature matrix and target vector
    X, y = digits.data, digits.target

    # Plot Validation Curve
    # Create range of values for parameter
    param_range = np.arange(1, 250, 2)

    # Calculate accuracy on training and test set using range of parameter values
    train_scores, test_scores = validation_curve(RandomForestClassifier(),
                                  X, y, param_name="n_estimators", param_range=param_range,
                                  cv=4, scoring="accuracy", n_jobs=-1)

    # Calculate mean and standard deviation for training set scores
    train_mean = np.mean(train_scores, axis=1)
    train_std = np.std(train_scores, axis=1)

    # Calculate mean and standard deviation for test set scores
    test_mean = np.mean(test_scores, axis=1)
    test_std = np.std(test_scores, axis=1)

    # Plot mean accuracy scores for training and test sets
    plt.figure(figsize=(7, 7))
    plt.plot(param_range, train_mean, label="Training score", color="black")
    plt.plot(param_range, test_mean, label="Cross-validation score", color="dimgrey")

    # Plot bands of one standard deviation around the mean accuracy for training and test sets
    plt.fill_between(param_range, train_mean - train_std, train_mean + train_std, color="gray")
    plt.fill_between(param_range, test_mean - test_std, test_mean + test_std, color="gainsboro")

    # Add title, axis labels, and legend, then show the plot
    plt.title("Validation Curve With Random Forest")
    plt.xlabel("Number Of Trees")
    plt.ylabel("Accuracy Score")
    plt.tight_layout()
    plt.legend(loc="best")
    plt.show()

Snippet_141()
**********************How to plot Validation Curve in Python**********************
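
The recipe above averages the scores over `axis=1` before plotting. The short sketch below illustrates why: `validation_curve` returns one row per parameter value and one column per cross-validation fold. It also shows one possible way to read a candidate value off the curve, by taking the parameter with the highest mean cross-validation accuracy. The three-value grid, the fixed `random_state=0`, and the name `best_n_estimators` are assumptions made here to keep the run short and repeatable; they are not part of the original recipe.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

# Same dataset as the recipe
X, y = load_digits(return_X_y=True)

# A deliberately short grid (assumption: chosen only to keep the run fast)
param_range = np.array([1, 10, 50])

train_scores, test_scores = validation_curve(
    RandomForestClassifier(random_state=0),  # fixed seed is an assumption, for repeatability
    X, y,
    param_name="n_estimators",
    param_range=param_range,
    cv=4,
    scoring="accuracy",
)

# One row per parameter value, one column per cross-validation fold
print(train_scores.shape, test_scores.shape)

# Average over folds (axis=1) to get one score per parameter value
test_mean = test_scores.mean(axis=1)
test_std = test_scores.std(axis=1)

# Parameter value with the highest mean cross-validation accuracy
best_index = int(np.argmax(test_mean))
best_n_estimators = int(param_range[best_index])
print("best n_estimators:", best_n_estimators)
```

Taking the maximum of the mean is only a simple heuristic. When several values score within one standard deviation of each other, the smaller forest is often the more practical choice because it trains and predicts faster.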

