How to do cost complexity pruning in decision tree classifier?

This recipe shows how to apply cost complexity pruning to a decision tree classifier.

Recipe Objective

How to do cost complexity pruning in a decision tree classifier

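Pruning is a technique for reducing overfitting in decision trees. By default, a tree keeps splitting until its leaves are pure, so it tends to memorize noise in the training data. Pruning removes the branches that contribute least to predictive performance, leaving a smaller tree that usually generalizes better. Cost complexity pruning controls this trade-off with a single non-negative parameter, ccp_alpha: the larger its value, the more of the tree is pruned.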
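For reference, the quantity that cost complexity pruning minimizes can be written as below. This follows the formulation in the scikit-learn documentation; the symbols are notation introduced here, not taken from this recipe. R(T) is the total impurity of the leaves of tree T, |T̃| is the number of leaves, and α ≥ 0 is the complexity parameter that scikit-learn exposes as ccp_alpha.

```latex
R_\alpha(T) = R(T) + \alpha \, |\widetilde{T}|
```

With α = 0 the full tree is kept; as α grows, each extra leaf costs more and smaller subtrees are preferred.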

Step 1- Importing Libraries.

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score

Step 2- Importing dataset

iris = sns.load_dataset('iris')
iris.head()

Step 3- Preparing the dataset.

X = iris.drop(columns='species')
y = iris['species']
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, test_size=0.3, random_state=20)

Step 4- Fitting model to Decision Tree Classifier.

Fit an unpruned decision tree classifier and check its accuracy on the training and test sets.

tree = DecisionTreeClassifier()
tree.fit(Xtrain, ytrain)
ytrain_pred = tree.predict(Xtrain)
ytest_pred = tree.predict(Xtest)
print(accuracy_score(ytrain, ytrain_pred), accuracy_score(ytest, ytest_pred))
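Before pruning, it can help to see how large the unpruned tree is. The sketch below is a standalone illustration, not part of the original recipe: it assumes scikit-learn's bundled load_iris in place of sns.load_dataset('iris') so that no download is needed, and it fixes random_state=0 on the tree for repeatability. The variable names are hypothetical.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: load_iris stands in for sns.load_dataset('iris') (same data, no download).
X, y = load_iris(return_X_y=True)
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, test_size=0.3, random_state=20)

# Fit a tree with no pruning (ccp_alpha defaults to 0.0).
unpruned = DecisionTreeClassifier(random_state=0).fit(Xtrain, ytrain)

# score() returns mean accuracy, equivalent to accuracy_score on the predictions.
train_acc = unpruned.score(Xtrain, ytrain)
test_acc = unpruned.score(Xtest, ytest)

# Size of the fitted tree: depth and number of leaves.
depth = unpruned.get_depth()
n_leaves = unpruned.get_n_leaves()

print(f"train accuracy={train_acc:.3f}, test accuracy={test_acc:.3f}")
print(f"depth={depth}, leaves={n_leaves}")
```

A gap between training and test accuracy, together with a deep tree, is the usual sign that pruning may help.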

Step 5- Applying Pruning.

prun = tree.cost_complexity_pruning_path(Xtrain, ytrain)
alphas = prun['ccp_alphas']
alphas
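The object returned by cost_complexity_pruning_path holds two arrays: ccp_alphas, the effective alphas at which the tree changes, and impurities, the total leaf impurity of the pruned tree at each of those alphas. The sketch below prints them side by side. It is a standalone illustration that assumes load_iris in place of the seaborn loader and a fixed random_state on the tree.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: load_iris stands in for sns.load_dataset('iris').
X, y = load_iris(return_X_y=True)
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, test_size=0.3, random_state=20)

# The pruning path is computed from the training data only.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(Xtrain, ytrain)
ccp_alphas = path.ccp_alphas    # candidate alpha values, smallest first
impurities = path.impurities    # total leaf impurity of the subtree at each alpha

for alpha, impurity in zip(ccp_alphas, impurities):
    print(f"alpha={alpha:.5f}  total leaf impurity={impurity:.5f}")
```

The smallest alpha corresponds to the unpruned tree, and the largest alpha prunes the most.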

Step 6- Pruning the tree for each alpha.

Train one tree for each candidate value of ccp_alpha on the training set, then plot the training and test accuracy against alpha to see how pruning affects both.

train_accuracy, test_accuracy = [], []
for j in alphas:
    tree = DecisionTreeClassifier(ccp_alpha=j)
    tree.fit(Xtrain, ytrain)
    ytrain_pred = tree.predict(Xtrain)
    ytest_pred = tree.predict(Xtest)
    train_accuracy.append(accuracy_score(ytrain, ytrain_pred))
    test_accuracy.append(accuracy_score(ytest, ytest_pred))

sns.set()
plt.figure(figsize=(10, 6))
sns.lineplot(y=train_accuracy, x=alphas, label='Training Accuracy')
sns.lineplot(y=test_accuracy, x=alphas, label='Testing Accuracy')
plt.show()
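Reading the best alpha off the test-accuracy curve uses the test set for model selection. One common alternative, sketched below, is to choose alpha by cross-validation on the training set and touch the test set only once at the end. This is a swapped-in technique, not the recipe's own method. It assumes load_iris in place of the seaborn loader, 5-fold cross-validation, and a fixed random_state; the names cv_means, best_alpha and final_tree are hypothetical.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: load_iris stands in for sns.load_dataset('iris').
X, y = load_iris(return_X_y=True)
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, test_size=0.3, random_state=20)

# Candidate alphas from the pruning path of the training data.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(Xtrain, ytrain)
# Guard against tiny negative values from floating-point error and drop duplicates.
alphas = np.unique(np.clip(path.ccp_alphas, 0.0, None))

# Mean 5-fold cross-validated accuracy for each alpha, using the training set only.
cv_means = [
    cross_val_score(
        DecisionTreeClassifier(random_state=0, ccp_alpha=a), Xtrain, ytrain, cv=5
    ).mean()
    for a in alphas
]

# Pick the alpha with the highest mean CV accuracy (first one on ties).
best_alpha = alphas[int(np.argmax(cv_means))]

# Refit on the full training set and evaluate once on the held-out test set.
final_tree = DecisionTreeClassifier(random_state=0, ccp_alpha=best_alpha).fit(Xtrain, ytrain)
test_acc = final_tree.score(Xtest, ytest)

print(f"best alpha={best_alpha:.5f}, leaves={final_tree.get_n_leaves()}, test accuracy={test_acc:.3f}")
```

Because np.argmax returns the first maximum, ties are resolved in favor of the smaller alpha; preferring the larger alpha on ties would give a simpler tree instead.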
