How to plot a Validation Curve in Python?

This recipe helps you plot a validation curve in Python

Recipe Objective

While working on a dataset, we train a model and check its accuracy. If we measure the accuracy on the same data that was used for training, it comes out very high because the model has already seen that data. For a realistic assessment, we have to check the accuracy on unseen data, and do so for different values of a model parameter, to get a better view of how the model behaves.
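The gap between accuracy on seen and unseen data can be checked directly. The following is a minimal sketch, assuming the Iris data used later in this recipe and a hold-out split made with train_test_split. The split size, the number of trees and the random_state values are arbitrary choices for illustration and are not part of the original recipe.

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = datasets.load_iris(return_X_y=True)

# Hold out 30% of the rows so the model never sees them during training
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

model = RandomForestClassifier(n_estimators=50, random_state=0)
model.fit(X_train, y_train)

train_acc = model.score(X_train, y_train)  # accuracy on data the model has seen
test_acc = model.score(X_test, y_test)     # accuracy on unseen data

print(f"Training accuracy: {train_acc:.3f}")
print(f"Held-out accuracy: {test_acc:.3f}")
```

The training figure is usually the more optimistic of the two, which is why the rest of the recipe relies on cross-validated scores.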

This data science Python source code does the following:
1. Imports the Iris dataset and the necessary libraries
2. Imports the validation_curve function
3. Computes training and cross-validation scores over a range of parameter values
4. Plots graphs using matplotlib to analyze the validation of the model

So this is the recipe on how to compute a validation curve and plot it.


Step 1 - Import the library

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn import datasets
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import validation_curve

We have imported all the modules that will be needed: matplotlib, numpy, datasets, RandomForestClassifier and validation_curve. We will see the use of each module step by step.

Step 2 - Setting up the Data

We have loaded the built-in Iris dataset from the datasets module and stored the data in X and the target in y.

    iris = datasets.load_iris()
    X, y = iris.data, iris.target
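Before moving on, it can help to confirm what X and y contain. A small sketch, using only load_iris and NumPy:

```python
import numpy as np
from sklearn import datasets

iris = datasets.load_iris()
X, y = iris.data, iris.target

print(X.shape)             # (150, 4): 150 samples, 4 features
print(y.shape)             # (150,)
print(np.unique(y))        # [0 1 2]: three classes
print(iris.feature_names)  # names of the four measurement columns
```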

Step 3 - Using Validation_Curve and calculating the scores

Here we are using RandomForestClassifier, so first we have to define the range of parameter values over which the validation curve is computed. We have created an array param_range for that.
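To get a feel for the size of this range, the sketch below builds the same array that the recipe uses and counts how many models would be fitted, assuming the cv=4 setting from the validation_curve call in this step.

```python
import numpy as np

# Odd values from 1 up to (but not including) 250
param_range = np.arange(1, 250, 2)

print(param_range[:5])   # [1 3 5 7 9]
print(param_range[-1])   # 249
print(len(param_range))  # 125

# One model is fitted per parameter value and per cross-validation fold
cv = 4
n_fits = len(param_range) * cv
print(n_fits)            # 500
```

With this many fits, a coarser step in np.arange is a reasonable way to shorten the run while exploring.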

Before using validation_curve, let us first look at its parameters:

    • estimator : The model (estimator object) whose parameter we want to study, here RandomForestClassifier().
    • param_name : The name of the parameter that will be varied, passed as a string.
    • param_range : The values of that parameter for which the scores are computed.
    • cv : The cross-validation strategy. Passing an integer sets the number of folds; by default five folds are used.
    • scoring : The metric used to calculate the score, for example "accuracy".
    • n_jobs : The number of jobs to run in parallel; -1 means use all processors.

    param_range = np.arange(1, 250, 2)
    train_scores, test_scores = validation_curve(RandomForestClassifier(),
                                                 X, y,
                                                 param_name="n_estimators",
                                                 param_range=param_range,
                                                 cv=4,
                                                 scoring="accuracy",
                                                 n_jobs=-1)

Now we calculate the mean and standard deviation of the training and testing scores across the cross-validation folds.

    train_mean = np.mean(train_scores, axis=1)
    train_std = np.std(train_scores, axis=1)
    test_mean = np.mean(test_scores, axis=1)
    test_std = np.std(test_scores, axis=1)
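The reason for axis=1 becomes clearer from the shape of the returned arrays: one row per parameter value and one column per fold. The sketch below uses a deliberately short range (small_range, a name chosen only for this illustration) and a fixed random_state so that it runs quickly; neither is part of the original recipe.

```python
import numpy as np
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

X, y = datasets.load_iris(return_X_y=True)

# Three parameter values instead of 125, purely to keep the example fast
small_range = np.array([1, 10, 50])
train_scores, test_scores = validation_curve(
    RandomForestClassifier(random_state=0), X, y,
    param_name="n_estimators", param_range=small_range,
    cv=4, scoring="accuracy",
)

print(train_scores.shape)  # (3, 4): 3 parameter values, 4 folds
print(test_scores.shape)   # (3, 4)

# Averaging over axis=1 collapses the folds, leaving one value per parameter
train_mean = np.mean(train_scores, axis=1)
test_mean = np.mean(test_scores, axis=1)
print(train_mean.shape)    # (3,)
print(test_mean.shape)     # (3,)
```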

 


Step 4 - Plotting the validation curve

First we plot the mean accuracy scores for both the training and the cross-validation sets. Then we draw a band of one standard deviation around each mean. The last few lines set the title, axis labels, layout and legend of the plot.

    plt.subplots(1, figsize=(7,7))

    # Mean accuracy for each value of n_estimators
    plt.plot(param_range, train_mean, label="Training score", color="black")
    plt.plot(param_range, test_mean, label="Cross-validation score", color="dimgrey")

    # Bands of one standard deviation around each mean
    plt.fill_between(param_range, train_mean - train_std, train_mean + train_std, color="gray")
    plt.fill_between(param_range, test_mean - test_std, test_mean + test_std, color="gainsboro")

    # Title, axis labels, layout and legend
    plt.title("Validation Curve With Random Forest")
    plt.xlabel("Number Of Trees")
    plt.ylabel("Accuracy Score")
    plt.tight_layout()
    plt.legend(loc="best")
    plt.show()
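Besides reading the plot by eye, the same arrays can be used to report which parameter value gave the highest mean cross-validation score. This is a sketch under stated assumptions rather than part of the original recipe: it uses a short illustrative range and a fixed random_state so that it runs quickly, and it simply takes the maximum of the mean scores.

```python
import numpy as np
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

X, y = datasets.load_iris(return_X_y=True)

# A short, illustrative range so the example runs quickly
param_range = np.array([1, 5, 25, 100])
train_scores, test_scores = validation_curve(
    RandomForestClassifier(random_state=0), X, y,
    param_name="n_estimators", param_range=param_range,
    cv=4, scoring="accuracy",
)
test_mean = np.mean(test_scores, axis=1)
test_std = np.std(test_scores, axis=1)

# np.argmax returns the first maximum, so ties go to the smaller forest
best_idx = int(np.argmax(test_mean))
best_n = int(param_range[best_idx])

print(f"Best n_estimators: {best_n}")
print(f"Mean CV accuracy: {test_mean[best_idx]:.3f} (std {test_std[best_idx]:.3f})")
```

On a small dataset such as Iris the differences between neighbouring values are often within one standard deviation, so the reported value is best treated as a rough guide.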
