How does K fold cross validation work in R?

Recipe Objective

The major challenge when building a model is making it perform accurately on unseen data. Cross-validation is one technique for checking how well a model generalizes. It holds back a portion of the data during training and later uses that portion as unseen data to test/validate the model, which gives us an estimate of the prediction error.

The three most common cross-validation techniques are:

  1. Leave-one-out cross-validation (LOOCV)
  2. K-fold cross-validation
  3. Repeated k-fold cross-validation

In this recipe, we will learn how to perform k-fold cross-validation while building a linear regression model in R.

The k-fold cross-validation technique splits the dataset into 'k' folds or subsets. The following steps take place:

  1. Randomly split the data into k folds or subsets.
  2. Train the model on all of the data except one held-out subset.
  3. Test the model against that held-out subset and record the prediction error.
  4. Repeat steps 2 and 3 until every subset has been used once as the test set.
  5. The final prediction error is the average of the errors from the k rounds.
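The steps above can be sketched by hand in base R. This is only an illustrative sketch on synthetic data: the data frame `toy`, its columns, and the variable names `fold_id`, `fold_rmse` and `cv_rmse` are assumptions made for the example and are not part of the recipe's dataset.

```r
# Minimal base-R sketch of k-fold cross-validation (synthetic data, hypothetical names)
set.seed(42)                                   # make the random split reproducible

n <- 100
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
toy$y <- 3 + 2 * toy$x1 - toy$x2 + rnorm(n)    # known linear signal plus noise

k <- 5
# Step 1: randomly assign each row to one of k folds of equal size
fold_id <- sample(rep(seq_len(k), length.out = n))

fold_rmse <- numeric(k)
for (i in seq_len(k)) {
  train_set <- toy[fold_id != i, ]             # Step 2: train on all folds except fold i
  test_set  <- toy[fold_id == i, ]             # Step 3: test on the held-out fold
  fit  <- lm(y ~ x1 + x2, data = train_set)
  pred <- predict(fit, newdata = test_set)
  fold_rmse[i] <- sqrt(mean((test_set$y - pred)^2))
}                                              # Step 4: the loop repeats for every fold

# Step 5: the cross-validation error is the average over the k folds
cv_rmse <- mean(fold_rmse)
print(fold_rmse)
print(cv_rmse)
```

RMSE is used here as the error measure; any other metric (for example MAE) could be averaged across folds in the same way.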

A k value of 5 or 10 is typically chosen because it tends to give a reasonable balance between bias and variance in the error estimate. Because every observation is used for both training and testing across the folds, no important data is permanently left out when building and evaluating the model.

STEP 1: Importing Necessary Libraries

library(caret)     # for cross-validation and model training
library(tidyverse) # for data manipulation

STEP 2: Read a csv file and explore the data

The dataset attached contains the data of 159 different bags associated with ABC industries.

The bags have certain attributes which are described below:

  1. Height – The height of the bag
  2. Width – The width of the bag
  3. Length – The length of the bag
  4. Weight – The weight the bag can carry
  5. Weight1 – Weight the bag can carry after expansion

The company now wants to predict the cost they should set for a new variant of these kinds of bags.

data <- read.csv("R_303_Data_1.csv")
glimpse(data)

Rows: 159
Columns: 6
$ Cost     242, 290, 340, 363, 430, 450, 500, 390, 450, 500, 475, 500,...
$ Weight   23.2, 24.0, 23.9, 26.3, 26.5, 26.8, 26.8, 27.6, 27.6, 28.5,...
$ Weight1  25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7,...
$ Length   30.0, 31.2, 31.1, 33.5, 34.0, 34.7, 34.5, 35.0, 35.1, 36.2,...
$ Height   11.5200, 12.4800, 12.3778, 12.7300, 12.4440, 13.6024, 14.17...
$ Width    4.0200, 4.3056, 4.6961, 4.4555, 5.1340, 4.9274, 5.2785, 4.6...

summary(data) # returns the statistical summary of the data columns

Cost            Weight         Weight1          Length     
 Min.   :   0.0   Min.   : 7.50   Min.   : 8.40   Min.   : 8.80  
 1st Qu.: 120.0   1st Qu.:19.05   1st Qu.:21.00   1st Qu.:23.15  
 Median : 273.0   Median :25.20   Median :27.30   Median :29.40  
 Mean   : 398.3   Mean   :26.25   Mean   :28.42   Mean   :31.23  
 3rd Qu.: 650.0   3rd Qu.:32.70   3rd Qu.:35.50   3rd Qu.:39.65  
 Max.   :1650.0   Max.   :59.00   Max.   :63.40   Max.   :68.00  
     Height           Width      
 Min.   : 1.728   Min.   :1.048  
 1st Qu.: 5.945   1st Qu.:3.386  
 Median : 7.786   Median :4.248  
 Mean   : 8.971   Mean   :4.417  
 3rd Qu.:12.366   3rd Qu.:5.585  
 Max.   :18.957   Max.   :8.142   

dim(data)

159 6

STEP 3: Performing K-Fold Cross validation

We will use the caret package to perform cross-validation. First, we use the trainControl() function to define the cross-validation method, and then we pass that definition to the train() function.

Syntax: train(formula, data = , method = , trControl = , tuneGrid = )

where:

  1. formula = y~x1+x2+x3+..., where y is the dependent (response) variable and x1, x2, x3 are the independent (predictor) variables
  2. data = dataframe
  3. method = Type of model to be built ("lm" for linear regression)
  4. trControl = Takes the control parameters. We will use the trainControl() function here to specify the cross-validation technique.
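As a rough guide, the three cross-validation techniques listed at the start of this recipe correspond to different `method` values in caret's trainControl(). The object names below (`ctrl_loocv`, `ctrl_kfold`, `ctrl_repeated`) and the choices of 10 folds and 3 repeats are illustrative assumptions only.

```r
library(caret)

# Leave-one-out cross-validation: each observation is held out once
ctrl_loocv <- trainControl(method = "LOOCV")

# k-fold cross-validation, here with k = 10
ctrl_kfold <- trainControl(method = "cv", number = 10)

# Repeated k-fold: 10 folds, with the whole procedure repeated 3 times
ctrl_repeated <- trainControl(method = "repeatedcv", number = 10, repeats = 3)
```

Any of these objects can be supplied to train() through its trControl argument.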

# specify the CV technique that will be passed into the train() function later
train_control <- trainControl(method = "cv", number = 5)

# train a linear regression model with 5-fold cross-validation
model <- train(Cost ~ ., data = data, method = "lm", trControl = train_control)

# summarise the results
print(model)

Linear Regression 

159 samples
  5 predictor

No pre-processing
Resampling: Cross-Validated (5 fold) 
Summary of sample sizes: 128, 128, 127, 126, 127 
Resampling results:

  RMSE      Rsquared   MAE     
  123.8312  0.8897412  93.41378

Tuning parameter 'intercept' was held constant at a value of TRUE

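Note: The RMSE, R-squared and MAE shown above are averages over the 5 folds and serve as the cross-validation estimates of model performance. Because the folds are assigned randomly, the exact values will vary between runs unless a seed is set.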
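To see where the averaged figures come from, the per-fold metrics can be inspected on the fitted object. The sketch below uses R's built-in mtcars data instead of the bags dataset so that it is self-contained; the formula `mpg ~ wt + hp`, the seed, and the names `ctrl` and `fit` are assumptions made for illustration.

```r
library(caret)

set.seed(123)                                  # fix the random fold assignment
ctrl <- trainControl(method = "cv", number = 5)
fit  <- train(mpg ~ wt + hp, data = mtcars, method = "lm", trControl = ctrl)

# one row of RMSE, Rsquared and MAE per held-out fold
print(fit$resample)

# the averaged metrics that print(fit) reports
print(fit$results)

# the reported RMSE is the mean of the per-fold RMSE values
mean(fit$resample$RMSE)
fit$results$RMSE
```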
