How to create and optimize a baseline Decision Tree model for Binary Classification in R?

This recipe helps you create and optimize a baseline Decision Tree model for Binary Classification in R.

Recipe Objective

A Decision Tree is a supervised machine learning algorithm that can perform both classification and regression on complex datasets. Decision trees are also known as Classification and Regression Trees (CART), and they work with both continuous and categorical variables.

Important basic tree terminology is as follows:

  1. Root node: represents the entire population or dataset, which gets divided into two or more purer (more homogeneous) subsets. It always splits on a single input variable (x).
  2. Leaf or terminal node: a node that does not split further and holds the output variable.

In this recipe, we will focus only on Classification Trees, where the target variable is categorical in nature. The splits in these trees are based on the homogeneity of the groups formed. The homogeneity or impurity of the data is quantified by computing metrics such as Entropy, Information Gain and Gini Index.

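A commonly used metric is Information Gain. It quantifies how much information a feature variable provides about the class, measured as the reduction in entropy after a split. Note that rpart, the package behind the tree fitted in this recipe, splits on the Gini index by default.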
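To make these impurity metrics concrete, here is a minimal base-R sketch. The helper names (entropy, gini, info_gain) and the toy labels are assumptions made for illustration; they are not part of caret or rpart.

```r
# Entropy of a vector of class labels (in bits).
entropy <- function(labels) {
  p <- table(labels) / length(labels)
  p <- p[p > 0]                      # drop empty classes to avoid log2(0)
  -sum(p * log2(p))
}

# Gini index of a vector of class labels.
gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

# Information gain of splitting `labels` into the groups given by `groups`:
# parent entropy minus the size-weighted entropy of the child nodes.
info_gain <- function(labels, groups) {
  children <- split(labels, groups)
  weights <- lengths(children) / length(labels)
  child_entropy <- sapply(children, entropy)
  entropy(labels) - sum(weights * child_entropy)
}

# Toy data (hypothetical): 10 patients, split on a made-up "glucose level" flag.
labels <- c(1, 1, 1, 1, 0, 0, 0, 0, 0, 1)
groups <- rep(c("high", "low"), each = 5)

entropy(labels)            # 1 (perfectly mixed parent node)
gini(labels)               # 0.5
info_gain(labels, groups)  # about 0.278
```

A split with higher information gain leaves purer child nodes, which is what the tree-growing procedure looks for at each step.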

This recipe demonstrates modelling and optimising a Classification Tree for binary classification. We use a well-known diabetes dataset from the National Institute of Diabetes and Digestive and Kidney Diseases.

STEP 1: Importing Necessary Libraries

library(caret)     # for data splitting, cross-validation and model training
library(tidyverse) # for data manipulation

STEP 2: Read a csv file and explore the data

Data Description: This dataset consists of several medical predictor variables (also known as the independent variables) and one target variable (Outcome).

Independent Variables:

  1. Pregnancies
  2. Glucose
  3. BloodPressure
  4. SkinThickness
  5. Insulin
  6. BMI
  7. DiabetesPedigreeFunction
  8. Age

Dependent Variable:

Outcome (0 = 'does not have diabetes', 1 = 'has diabetes')

data <- read.csv("R_344_diabetes.csv")
glimpse(data)
Rows: 768
Columns: 9
$ Pregnancies               6, 1, 8, 1, 0, 5, 3, 10, 2, 8, 4, 10, 10, ...
$ Glucose                   148, 85, 183, 89, 137, 116, 78, 115, 197, ...
$ BloodPressure             72, 66, 64, 66, 40, 74, 50, 0, 70, 96, 92,...
$ SkinThickness             35, 29, 0, 23, 35, 0, 32, 0, 45, 0, 0, 0, ...
$ Insulin                   0, 0, 0, 94, 168, 0, 88, 0, 543, 0, 0, 0, ...
$ BMI                       33.6, 26.6, 23.3, 28.1, 43.1, 25.6, 31.0, ...
$ DiabetesPedigreeFunction  0.627, 0.351, 0.672, 0.167, 2.288, 0.201, ...
$ Age                       50, 31, 32, 21, 33, 30, 26, 29, 53, 54, 30...
$ Outcome                   1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, ...
summary(data) # returns the statistical summary of the data columns
Pregnancies        Glucose      BloodPressure    SkinThickness  
 Min.   : 0.000   Min.   :  0.0   Min.   :  0.00   Min.   : 0.00  
 1st Qu.: 1.000   1st Qu.: 99.0   1st Qu.: 62.00   1st Qu.: 0.00  
 Median : 3.000   Median :117.0   Median : 72.00   Median :23.00  
 Mean   : 3.845   Mean   :120.9   Mean   : 69.11   Mean   :20.54  
 3rd Qu.: 6.000   3rd Qu.:140.2   3rd Qu.: 80.00   3rd Qu.:32.00  
 Max.   :17.000   Max.   :199.0   Max.   :122.00   Max.   :99.00  
    Insulin           BMI        DiabetesPedigreeFunction      Age       
 Min.   :  0.0   Min.   : 0.00   Min.   :0.0780           Min.   :21.00  
 1st Qu.:  0.0   1st Qu.:27.30   1st Qu.:0.2437           1st Qu.:24.00  
 Median : 30.5   Median :32.00   Median :0.3725           Median :29.00  
 Mean   : 79.8   Mean   :31.99   Mean   :0.4719           Mean   :33.24  
 3rd Qu.:127.2   3rd Qu.:36.60   3rd Qu.:0.6262           3rd Qu.:41.00  
 Max.   :846.0   Max.   :67.10   Max.   :2.4200           Max.   :81.00  
    Outcome     
 Min.   :0.000  
 1st Qu.:0.000  
 Median :0.000  
 Mean   :0.349  
 3rd Qu.:1.000  
 Max.   :1.000  
dim(data)
768 9
# Converting the dependent variable into factor levels
data$Outcome = as.factor(data$Outcome)

STEP 3: Train Test Split

# createDataPartition() from the caret package splits the original dataset
# into a training set (80%) and a testing set (20%), stratified on the target variable
parts = createDataPartition(data$Outcome, p = .8, list = FALSE)
train = data[parts, ]
test = data[-parts, ]
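As a rough illustration of what a stratified 80/20 split does, the base-R sketch below samples 80% of the rows within each class. It is not caret's implementation; the outcome vector is a stand-in built from the class counts implied by the summary above (mean of Outcome 0.349 over 768 rows), and the seed is arbitrary.

```r
# Stand-in for data$Outcome: 500 negatives and 268 positives (assumed counts).
outcome <- factor(c(rep(0, 500), rep(1, 268)))

set.seed(123)  # arbitrary seed, only so this sketch is reproducible

# Sample 80% of the row indices within each class, then combine them.
train_idx <- unlist(
  lapply(split(seq_along(outcome), outcome), function(idx) {
    sample(idx, size = ceiling(0.8 * length(idx)))
  }),
  use.names = FALSE
)

train_outcome <- outcome[train_idx]
test_outcome  <- outcome[-train_idx]

length(train_outcome)             # 615
length(test_outcome)              # 153
prop.table(table(train_outcome))  # class shares close to those of the full data
```

Sampling within each class keeps the share of diabetic patients roughly the same in the training and testing sets, which matters when the classes are imbalanced.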

STEP 4: Building and optimising Baseline Classification Tree

We will use the caret package to perform cross-validation and hyperparameter tuning (maxdepth) using the grid search technique. First, we use the trainControl() function to define the cross-validation method and the search type, i.e. "grid" or "random". Then we train the model using the train() function with tuneGrid as one of the arguments.

Syntax: train(formula, data = , method = , trControl = , tuneGrid = )

where:

  1. formula = y~x1+x2+x3+..., where y is the dependent variable and x1, x2, x3 are the independent variables
  2. data = dataframe
  3. method = Type of the model to be built ("rpart2" for CART)
  4. trControl = Takes the control parameters. We will use trainControl function out here where we will specify the Cross validation technique.
  5. tuneGrid = takes the tuning parameters and applies grid search CV on them
# specifying the CV technique, which will be passed into the train() function later;
# the number parameter is the "k" in k-fold cross-validation
train_control = trainControl(method = "cv", number = 5, search = "grid")

# customising the tuning grid: candidate values for the maximum tree depth
classification_Tree_Grid = expand.grid(maxdepth = c(1,3,5,7,9))

set.seed(50)

# training a classification tree while tuning maxdepth (method = "rpart2")
model = train(Outcome~., data = train, method = "rpart2", trControl = train_control, tuneGrid = classification_Tree_Grid)

# summarising the results
print(model)
CART 

615 samples
  8 predictor
  2 classes: '0', '1' 

No pre-processing
Resampling: Cross-Validated (5 fold) 
Summary of sample sizes: 492, 492, 492, 492, 492 
Resampling results across tuning parameters:

  maxdepth  Accuracy   Kappa    
  1         0.7252033  0.3790189
  3         0.7317073  0.3941319
  5         0.7447154  0.4314799
  7         0.7430894  0.4193752
  9         0.7284553  0.3993505

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was maxdepth = 5.

Note: Accuracy was used to select the optimal model using the largest value, and the final model has a maximum depth of 5.
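The "Summary of sample sizes" line in the output follows directly from 5-fold cross-validation on 615 training rows. The base-R sketch below assigns rows to folds at random to show the arithmetic; caret builds its own folds internally, so this is only an illustration, and the variable names are assumptions.

```r
n <- 615  # training rows reported in the output above
k <- 5    # the `number` argument passed to trainControl()

set.seed(50)  # arbitrary seed for this sketch

# Assign each row to one of k folds of (nearly) equal size.
fold_id <- sample(rep(seq_len(k), length.out = n))

# In each round the model is fit on k - 1 folds and scored on the held-out fold.
train_sizes   <- sapply(seq_len(k), function(f) sum(fold_id != f))
holdout_sizes <- sapply(seq_len(k), function(f) sum(fold_id == f))

train_sizes    # 492 492 492 492 492
holdout_sizes  # 123 123 123 123 123
```

The accuracy reported for each maxdepth value is the average over these five held-out folds, which is why it can differ from the accuracy on the separate testing set.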

STEP 5: Make predictions on the final classification Tree model

We use our final classification Tree model to make predictions on the testing data (unseen data) and predict the 'Outcome' value and generate performance measures.

# use the model to make predictions on the test data
pred_y = predict(model, test)

# confusion matrix
confusionMatrix(data = pred_y, test$Outcome)
Confusion Matrix and Statistics

          Reference
Prediction  0  1
         0 88 23
         1 12 30
                                          
               Accuracy : 0.7712          
                 95% CI : (0.6965, 0.8352)
    No Information Rate : 0.6536          
    P-Value [Acc > NIR] : 0.001098        
                                          
                  Kappa : 0.4689          
                                          
 Mcnemar's Test P-Value : 0.090969        
                                          
            Sensitivity : 0.8800          
            Specificity : 0.5660          
         Pos Pred Value : 0.7928          
         Neg Pred Value : 0.7143          
             Prevalence : 0.6536          
         Detection Rate : 0.5752          
   Detection Prevalence : 0.7255          
      Balanced Accuracy : 0.7230          
                                          
       'Positive' Class : 0      
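The headline statistics can be recomputed by hand from the four counts in the confusion matrix, which helps when reading the report. This is a minimal base-R sketch; the variable names are assumptions, and class 0 ('does not have diabetes') is treated as the positive class, as in the output above.

```r
# Counts copied from the confusion matrix above
# (rows = prediction, columns = reference).
cm <- matrix(c(88, 12, 23, 30), nrow = 2,
             dimnames = list(Prediction = c("0", "1"),
                             Reference  = c("0", "1")))

tp <- cm["0", "0"]  # predicted 0, actually 0
fn <- cm["1", "0"]  # predicted 1, actually 0
fp <- cm["0", "1"]  # predicted 0, actually 1
tn <- cm["1", "1"]  # predicted 1, actually 1

accuracy    <- (tp + tn) / sum(cm)
sensitivity <- tp / (tp + fn)
specificity <- tn / (tn + fp)
ppv         <- tp / (tp + fp)
npv         <- tn / (tn + fn)
balanced    <- (sensitivity + specificity) / 2

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity, ppv = ppv, npv = npv,
        balanced = balanced), 4)
# accuracy 0.7712, sensitivity 0.8800, specificity 0.5660,
# ppv 0.7928, npv 0.7143, balanced 0.7230
```

Because the positive class is 0, the specificity of 0.566 is the share of diabetic patients (class 1) that the model identifies correctly: 30 out of 53.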
