How to create and optimize a baseline Decision Tree model for MultiClass Classification in R?

This recipe helps you create and optimize a baseline Decision Tree model for MultiClass Classification in R

Recipe Objective

A Decision Tree is a supervised machine learning algorithm that can perform both classification and regression on complex datasets. Decision trees are also known as Classification and Regression Trees (CART), so they work with both continuous and categorical target variables.

Important basic tree terminology is as follows:

  1. Root node: represents the entire population or dataset, which gets divided into two or more purer (more homogeneous) subsets. Each split is made on a single input variable (x).
  2. Leaf or terminal node: these nodes do not split further and contain the output variable.

In this recipe, we focus only on Classification Trees, where the target variable is categorical. The splits in these trees are based on the homogeneity of the groups formed. The homogeneity or impurity of the data is quantified by metrics such as Entropy, Information Gain and Gini Index.

A commonly used metric is Information Gain. It measures how much information a feature variable provides about the class, that is, how much a split on that feature reduces entropy.
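The metrics named above can be sketched in a few lines of base R. This is a minimal illustration, not the code rpart runs internally: the helper names `entropy`, `gini` and `info_gain` are hypothetical, and the split (all setosa flowers on one side, the other two species on the other) is an assumed example using the 50/50/50 class counts of the Iris data.

```r
# Entropy of a node, given the class counts in that node
entropy <- function(counts) {
  p <- counts[counts > 0] / sum(counts)  # drop empty classes to avoid log2(0)
  -sum(p * log2(p))
}

# Gini index of a node, given the class counts in that node
gini <- function(counts) {
  p <- counts / sum(counts)
  1 - sum(p^2)
}

# Information gain = parent entropy minus the size-weighted entropy of the children
info_gain <- function(parent, children) {
  n <- sum(parent)
  weighted <- sum(sapply(children, function(ch) sum(ch) / n * entropy(ch)))
  entropy(parent) - weighted
}

# Assumed example: 50 flowers per species in the parent node
parent <- c(setosa = 50, versicolor = 50, virginica = 50)
left   <- c(setosa = 50, versicolor = 0,  virginica = 0)   # pure node
right  <- c(setosa = 0,  versicolor = 50, virginica = 50)  # still mixed

print(entropy(parent))                       # impurity before the split
print(gini(parent))
print(info_gain(parent, list(left, right)))  # entropy removed by the split
```

A pure node has entropy and Gini index of 0, so a split that isolates one class completely yields a large gain. Note that rpart uses the Gini index by default for classification; information gain is available there as an option.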

This recipe demonstrates building and optimising a Classification Tree for multi-class classification, using the Iris dataset.

STEP 1: Importing Necessary Libraries

library(caret) # for data splitting, cross validation and model training
library(tidyverse) # for data manipulation

STEP 2: Read a csv file and explore the data

Data Description: The dataset consists of 50 samples from each of the three species of flower (setosa, virginica, versicolor)

Independent Variables: ​

  1. petal_length
  2. petal_width
  3. sepal_length
  4. sepal_width

Dependent Variable:

label (Iris-setosa, Iris-versicolor, Iris-virginica)

data <- read.csv("R_345_Iris.csv")
glimpse(data)

Rows: 150
Columns: 5
$ petal_length  5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4,...
$ petal_width   3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7,...
$ sepal_length  1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, 1.5, 1.4, 1.5, 1.5,...
$ sepal_width   0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, 0.2, 0.2, 0.1, 0.2,...
$ label         Iris-setosa, Iris-setosa, Iris-setosa, Iris-setosa, Ir...

summary(data) # returns the statistical summary of the data columns

petal_length    petal_width     sepal_length    sepal_width   
 Min.   :4.300   Min.   :2.000   Min.   :1.000   Min.   :0.100  
 1st Qu.:5.100   1st Qu.:2.800   1st Qu.:1.600   1st Qu.:0.300  
 Median :5.800   Median :3.000   Median :4.350   Median :1.300  
 Mean   :5.843   Mean   :3.054   Mean   :3.759   Mean   :1.199  
 3rd Qu.:6.400   3rd Qu.:3.300   3rd Qu.:5.100   3rd Qu.:1.800  
 Max.   :7.900   Max.   :4.400   Max.   :6.900   Max.   :2.500  
             label   
 Iris-setosa    :50  
 Iris-versicolor:50  
 Iris-virginica :50   

dim(data)

150 5

# Converting the dependent variable into factor levels
data$label = as.factor(data$label)

STEP 3: Train Test Split

# Use the createDataPartition() function from the caret package to split the original dataset
# into training (80%) and testing (20%) sets, stratified on the target variable
parts = createDataPartition(data$label, p = .8, list = FALSE)
train = data[parts, ]
test = data[-parts, ]
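To see what the stratified split produces, the sketch below repeats the step on R's built-in `iris` data frame. This is an assumption made so the example runs without the CSV file: the built-in data uses the column `Species` instead of `label`, and the names `idx`, `train_demo` and `test_demo` are illustrative only.

```r
library(caret)

data(iris)  # built-in copy of the data; the target column is Species, not label

set.seed(50)  # makes the random split reproducible

# Stratified 80/20 split on the target variable
idx <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
train_demo <- iris[idx, ]
test_demo  <- iris[-idx, ]

# Class counts in each part: the three species stay balanced
print(table(train_demo$Species))
print(table(test_demo$Species))
```

Because createDataPartition() samples within each class, the training set keeps the same class proportions as the full data, which matters for multi-class problems with few samples per class.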

STEP 4: Building and optimising Baseline Classification Tree for multi-class classification

We will use the caret package to perform cross validation and hyperparameter tuning (maxdepth) using the grid search technique. First, we use the trainControl() function to define the cross validation method and the search type, i.e. "grid" or "random". Then we train the model using the train() function with tuneGrid as one of the arguments.

Syntax: train(formula, data = , method = , trControl = , tuneGrid = )

where:

  1. formula = y~x1+x2+x3+..., where y is the dependent variable and x1,x2,x3 are the independent variables
  2. data = dataframe
  3. method = Type of the model to be built ("rpart2" for CART)
  4. trControl = Takes the control parameters. We pass the output of the trainControl() function here, which specifies the cross validation technique.
  5. tuneGrid = takes the tuning parameters and applies grid search CV on them

# Specifying the CV technique, which will be passed into the train() function later;
# the number parameter is the "k" in K-fold cross validation
train_control = trainControl(method = "cv", number = 5, search = "grid")

# Customising the tuning grid (candidate values for the maximum tree depth)
multi_classification_Tree_Grid = expand.grid(maxdepth = c(1,3,5,7,9))

set.seed(50)

# Training a classification tree while tuning maxdepth (method = "rpart2")
model = train(label~., data = train, method = "rpart2", trControl = train_control, tuneGrid = multi_classification_Tree_Grid)

# Summarising the results
print(model)

CART 

120 samples
  4 predictor
  3 classes: 'Iris-setosa', 'Iris-versicolor', 'Iris-virginica' 

No pre-processing
Resampling: Cross-Validated (5 fold) 
Summary of sample sizes: 96, 96, 96, 96, 96 
Resampling results across tuning parameters:

  maxdepth  Accuracy   Kappa
  1         0.6666667  0.500
  3         0.9166667  0.875
  5         0.9166667  0.875
  7         0.9166667  0.875
  9         0.9166667  0.875

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was maxdepth = 3.

Note: Accuracy was used to select the optimal model using the largest value. The maxdepth values 3 to 9 tie on accuracy, and the final model has a maximum depth of 3.
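The jump in accuracy between maxdepth = 1 and maxdepth = 3 in the table has a simple explanation: a tree with one split has only two leaves, so it can predict at most two of the three classes. The sketch below illustrates this by calling rpart directly instead of going through caret, on the built-in `iris` data (an assumption so the example runs without the CSV; the target column is `Species`). It reports training accuracy, not cross-validated accuracy, so the numbers are not expected to match the table exactly.

```r
library(rpart)  # recommended package shipped with standard R installations

data(iris)  # built-in data; column names differ from the CSV used in this recipe

# A tree limited to a single level of splits
stump <- rpart(Species ~ ., data = iris, method = "class",
               control = rpart.control(maxdepth = 1))
stump_pred <- predict(stump, iris, type = "class")

# A tree allowed up to three levels of splits
tree3 <- rpart(Species ~ ., data = iris, method = "class",
               control = rpart.control(maxdepth = 3))
tree3_pred <- predict(tree3, iris, type = "class")

# How many distinct classes the depth-1 tree ever predicts
n_classes_stump <- length(unique(as.character(stump_pred)))
print(n_classes_stump)

# Training accuracy of both trees
acc_stump <- mean(stump_pred == iris$Species)
acc_tree3 <- mean(tree3_pred == iris$Species)
print(c(depth1 = acc_stump, depth3 = acc_tree3))

# Text view of the split rules of the deeper tree
print(tree3)
```

With three balanced classes, a two-leaf tree cannot exceed roughly two thirds accuracy, which is consistent with the 0.667 reported for maxdepth = 1.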

STEP 5: Make predictions on the final classification Tree model

We use our final classification tree model to make predictions on the testing (unseen) data, predicting the 'label' value and generating a confusion matrix.

# Use the model to make predictions on the test data
pred_y = predict(model, test)

# Confusion matrix
confusionMatrix(data = pred_y, test$label)

Confusion Matrix and Statistics

                 Reference
Prediction        Iris-setosa Iris-versicolor Iris-virginica
  Iris-setosa              10               0              0
  Iris-versicolor           0              10              0
  Iris-virginica            0               0             10

Overall Statistics
                                     
               Accuracy : 1          
                 95% CI : (0.8843, 1)
    No Information Rate : 0.3333     
    P-Value [Acc > NIR] : 4.857e-15  
                                     
                  Kappa : 1          
                                     
 Mcnemar's Test P-Value : NA         

Statistics by Class:

                     Class: Iris-setosa Class: Iris-versicolor
Sensitivity                      1.0000                 1.0000
Specificity                      1.0000                 1.0000
Pos Pred Value                   1.0000                 1.0000
Neg Pred Value                   1.0000                 1.0000
Prevalence                       0.3333                 0.3333
Detection Rate                   0.3333                 0.3333
Detection Prevalence             0.3333                 0.3333
Balanced Accuracy                1.0000                 1.0000
                     Class: Iris-virginica
Sensitivity                         1.0000
Specificity                         1.0000
Pos Pred Value                      1.0000
Neg Pred Value                      1.0000
Prevalence                          0.3333
Detection Rate                      0.3333
Detection Prevalence                0.3333
Balanced Accuracy                   1.0000  
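The headline statistics in this output can be reproduced by hand from the confusion matrix counts. The sketch below uses base R only; the counts are copied from the matrix printed above, and the variable names are illustrative. The confidence interval is computed with an exact binomial test, which reproduces the interval shown in the output.

```r
# Confusion matrix counts (rows = prediction, columns = reference)
classes <- c("Iris-setosa", "Iris-versicolor", "Iris-virginica")
cm <- matrix(c(10,  0,  0,
                0, 10,  0,
                0,  0, 10),
             nrow = 3, byrow = TRUE,
             dimnames = list(Prediction = classes, Reference = classes))

# Overall accuracy: correctly classified samples divided by all samples
accuracy <- sum(diag(cm)) / sum(cm)

# Per-class sensitivity: correct predictions divided by the true count of each class
sensitivity <- diag(cm) / colSums(cm)

# Per-class positive predictive value: correct predictions divided by all predictions of each class
pos_pred_value <- diag(cm) / rowSums(cm)

# 95% confidence interval for accuracy from an exact binomial test
ci <- binom.test(sum(diag(cm)), sum(cm))$conf.int

print(accuracy)
print(sensitivity)
print(pos_pred_value)
print(round(ci, 4))
```

A perfect score on 30 test samples still leaves a lower confidence bound of about 0.88, so the result should be read as "very good on a small test set" and not as proof of a flawless model.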
