How to visualize decision trees in R?

This recipe helps you visualize decision trees in R

Recipe Objective

A decision tree is a supervised machine learning algorithm that can perform both classification and regression, including on complex datasets. Such trees are also known as Classification and Regression Trees (CART), and they can handle both continuous and categorical variables.
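As a minimal sketch of that dual use, the same rpart() call fits either kind of tree, and only the method argument changes. The built-in iris and mtcars datasets are used here as stand-ins; they are not the data from this recipe.

```r
library(rpart)

# Classification tree: categorical target (Species), method = "class"
class_tree <- rpart(Species ~ ., data = iris, method = "class")
class_pred <- predict(class_tree, iris, type = "class")  # factor of predicted labels

# Regression tree: continuous target (mpg), method = "anova"
reg_tree <- rpart(mpg ~ ., data = mtcars, method = "anova")
reg_pred <- predict(reg_tree, mtcars)                     # numeric predicted values

print(table(predicted = class_pred, actual = iris$Species))
print(head(reg_pred))
```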

Important basic tree terminology is as follows:

  1. Root node: represents the entire population or dataset, which gets divided into two or more subsets that are more homogeneous (rpart always makes binary splits, i.e. two subsets). The split at the root is based on a single input variable (x).
  2. Leaf or terminal node: these nodes do not split further and contain the predicted value of the output variable.
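To connect these terms to a fitted model, the sketch below assumes a regression tree on the built-in mtcars data (a stand-in for the bag data). rpart stores one row per node in model$frame: node 1 is the root, and rows whose var column reads "<leaf>" are terminal nodes whose yval is the value predicted for every observation that lands there.

```r
library(rpart)

fit <- rpart(mpg ~ ., data = mtcars, method = "anova")
fr <- fit$frame                                  # one row per node; row names are node numbers
is_leaf <- as.character(fr$var) == "<leaf>"

# Root node: node number 1, holds every training row and splits on one variable
cat("Root splits on:", as.character(fr$var[1]), "with n =", fr$n[1], "\n")

# Leaf nodes: no further split; yval is the predicted (mean) outcome
print(fr[is_leaf, c("n", "yval")])
```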

In this recipe, we focus only on building and visualising a regression tree, where the target variable is continuous.

STEP 1: Importing Necessary Libraries

# For data manipulation
library(tidyverse)
# For the Decision Tree algorithm
library(rpart)
# For plotting the decision tree
install.packages("rpart.plot")
library(rpart.plot)
# Install the readxl R package for reading Excel sheets
install.packages("readxl")
library("readxl")
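The install.packages() calls above reinstall the packages on every run. A minimal alternative is sketched below as a hypothetical helper, install_if_missing() (not part of any package), which installs only the packages that requireNamespace() reports as unavailable.

```r
# Hypothetical helper: install only packages that cannot be loaded yet
install_if_missing <- function(pkgs) {
  have <- vapply(pkgs, requireNamespace, logical(1), quietly = TRUE)
  missing <- pkgs[!have]
  if (length(missing) > 0) install.packages(missing)
  invisible(missing)  # returns the names that had to be installed
}

# Usage for this recipe (commented out so the sketch also runs offline):
# pkgs <- c("tidyverse", "rpart", "rpart.plot", "readxl")
# install_if_missing(pkgs)
# invisible(lapply(pkgs, library, character.only = TRUE))
```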

STEP 2: Loading the Train and Test Datasets

The train and test datasets are stored separately, having been split in an 80/20 proportion respectively. Only the training set is loaded below, since it is all that is needed to build and plot the tree.
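The split itself is not part of the recipe. If the data started out as a single table, an 80/20 split could be made with sample(), as in this minimal sketch (mtcars stands in for the bag data, and the seed value is arbitrary):

```r
set.seed(42)                          # arbitrary seed for a reproducible split
full_data <- mtcars                   # stand-in for the full bag dataset
n <- nrow(full_data)
train_idx <- sample(seq_len(n), size = floor(0.8 * n))  # 80% of the row indices

train_df <- full_data[train_idx, ]    # training rows
test_df  <- full_data[-train_idx, ]   # the remaining 20%

c(train = nrow(train_df), test = nrow(test_df))
```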

Dataset description: The company wants to predict the cost (Cost) it should set for a new variant of its bags, based on the following attributes:

  1. Height – The height of the bag
  2. Width – The width of the bag
  3. Length – The length of the bag
  4. Weight – The weight the bag can carry
  5. Weight1 – Weight the bag can carry after expansion

# Calling the function read_excel from the readxl library
train = read_excel('R_285_df_train_regression.xlsx')
# Gives the number of observations and variables involved, with a brief description
glimpse(train)

Rows: 127
Columns: 6
$ Cost     242, 290, 340, 363, 430, 450, 500, 390, 450, 500, 475, 500,...
$ Weight   23.2, 24.0, 23.9, 26.3, 26.5, 26.8, 26.8, 27.6, 27.6, 28.5,...
$ Weight1  25.4, 26.3, 26.5, 29.0, 29.0, 29.7, 29.7, 30.0, 30.0, 30.7,...
$ Length   30.0, 31.2, 31.1, 33.5, 34.0, 34.7, 34.5, 35.0, 35.1, 36.2,...
$ Height   11.5200, 12.4800, 12.3778, 12.7300, 12.4440, 13.6024, 14.17...
$ Width    4.0200, 4.3056, 4.6961, 4.4555, 5.1340, 4.9274, 5.2785, 4.6...

STEP 3: Data Preprocessing (Scaling)

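This is a pre-modelling step in which the independent variables are standardised so that attributes measured in different units become comparable. Standardised data has mean zero and standard deviation one; we do this using the scale() function.

Note: Scaling is an important pre-modelling step for distance- or gradient-based methods (for example k-nearest neighbours), but decision trees do not require it. Each split compares a single variable against a threshold, so standardising a predictor changes the threshold values shown in the plot, not how the rows are partitioned. Skipping the step keeps the plotted splits in the original units (e.g. bag height), which can make the tree easier to read.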
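The claim that a tree does not need scaling can be checked directly. A minimal sketch, using mtcars as a stand-in for the bag data: fit the same regression tree on raw and on standardised predictors, then compare the node each row lands in and the fitted values.

```r
library(rpart)

raw <- mtcars
# Standardise the predictors only; the target mpg (column 1) keeps its units
scaled <- data.frame(scale(mtcars[, -1]), mpg = mtcars$mpg)

fit_raw    <- rpart(mpg ~ ., data = raw, method = "anova")
fit_scaled <- rpart(mpg ~ ., data = scaled, method = "anova")

# Same leaf assignment for every row, hence identical fitted values;
# only the split thresholds printed in the two trees differ
all.equal(predict(fit_raw), predict(fit_scaled))
```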

# Scaling the independent variables (columns 2 to 6) in the train dataset
train_scaled = scale(train[2:6])
# Using cbind() to add the target column Outcome to the scaled independent variables
train_scaled = data.frame(cbind(train_scaled, Outcome = train$Cost))
train_scaled %>% head()
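scale() standardises each column as z = (x - mean(x)) / sd(x). A small sketch, using three mtcars columns as stand-ins for the bag attributes, reproduces that by hand and confirms each standardised column has mean zero and standard deviation one:

```r
x <- mtcars[, c("wt", "hp", "disp")]   # stand-in numeric predictors

z_builtin <- scale(x)
# Subtract each column mean, then divide by each column standard deviation
z_manual <- sweep(sweep(as.matrix(x), 2, colMeans(x)), 2, apply(x, 2, sd), "/")

round(colMeans(z_builtin), 10)   # column means, effectively zero
apply(z_builtin, 2, sd)          # column standard deviations
```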

Weight		Weight1		Length		Height		Width		Outcome
-0.33379271	-0.3132781	-0.08858827	0.4095324	-0.42466337	242
-0.22300101	-0.1970948	0.04945726	0.6459374	-0.22972408	290
-0.23684997	-0.1712763	0.03795346	0.6207701	0.03681581	340
0.09552513	0.1514550	0.31404453	0.7075012	-0.12740825	363
0.12322305	0.1514550	0.37156350	0.6370722	0.33570907	430
0.16476994	0.2418198	0.45209006	0.9223343	0.19469206	450

STEP 4: Creation of Decision Tree Regressor model using training set

We use the rpart() function to fit the model.

Syntax: rpart(formula, data = , method = '')

Where:

  1. Formula of the decision tree: Outcome ~ . where Outcome is the dependent variable and . represents all other (independent) variables
  2. data = train_scaled
  3. method = 'anova' (to fit a regression model)
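With method = 'anova', each candidate split is scored by how much it reduces the sum of squared errors (SSE) of the outcome. A minimal sketch of that search for a single numeric predictor follows; best_split() is a hypothetical helper written for illustration, and the toy data are made up. rpart additionally enforces limits such as minsplit and minbucket, which the sketch ignores, so the comparison stump relaxes them.

```r
library(rpart)

# Sum of squared errors around the mean of y
sse <- function(y) sum((y - mean(y))^2)

# Hypothetical helper: try every midpoint between adjacent distinct x values
# and keep the threshold with the largest SSE reduction
best_split <- function(x, y) {
  vals <- sort(unique(x))
  cuts <- (head(vals, -1) + tail(vals, -1)) / 2
  gain <- vapply(cuts, function(t) {
    sse(y) - sse(y[x < t]) - sse(y[x >= t])
  }, numeric(1))
  list(threshold = cuts[which.max(gain)], reduction = max(gain))
}

# Toy data: the obvious cut lies between x = 3 and x = 4
x <- 1:6
y <- c(1, 1, 1, 10, 10, 10)
best_split(x, y)

# A one-split rpart tree (stump) on the same data, for comparison
toy <- data.frame(x = x, y = y)
stump <- rpart(y ~ x, data = toy, method = "anova",
               control = rpart.control(minsplit = 2, minbucket = 1,
                                       cp = 0, maxdepth = 1, xval = 0))
stump$splits[1, "index"]   # threshold of the root split
```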

# Creation of an object 'model' using the rpart function
model = rpart(Outcome ~ ., data = train_scaled, method = 'anova')
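The recipe loads only the training file, but a test set split off like the one described in Step 2 would need the same preprocessing before predict() is used. A minimal sketch, using mtcars as a stand-in and an arbitrary 25/7 row split: the test predictors are standardised with the training means and standard deviations, which scale() stores in the "scaled:center" and "scaled:scale" attributes, rather than with their own statistics.

```r
library(rpart)

train_x <- mtcars[1:25, -1]    # stand-in training predictors
test_x  <- mtcars[26:32, -1]   # stand-in test predictors

train_z <- scale(train_x)
# Reuse the training centre and scale so both sets share one transformation
test_z <- scale(test_x,
                center = attr(train_z, "scaled:center"),
                scale  = attr(train_z, "scaled:scale"))

train_df <- data.frame(train_z, Outcome = mtcars$mpg[1:25])
test_df  <- data.frame(test_z,  Outcome = mtcars$mpg[26:32])

model <- rpart(Outcome ~ ., data = train_df, method = "anova")
pred  <- predict(model, newdata = test_df)
rmse  <- sqrt(mean((test_df$Outcome - pred)^2))   # test-set root mean squared error
rmse
```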

STEP 5: Visualising the Decision Tree

We use the rpart.plot() function to plot the decision tree model:

rpart.plot(model)
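If rpart.plot is not available, the rpart package's own methods give a plainer view of the same kind of model. A minimal sketch, assuming mtcars as a stand-in dataset: print() lists one line per node, and plot() plus text() draw the tree with base graphics.

```r
library(rpart)

model <- rpart(mpg ~ ., data = mtcars, method = "anova")

# Text view: one line per node; "*" marks terminal (leaf) nodes
print(model)

# Graphical view with rpart's base-graphics methods
plot(model, uniform = TRUE, margin = 0.1)  # uniform = equal vertical spacing between levels
text(model, use.n = TRUE)                  # label the splits and show n in each leaf
```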
