How to implement Naive Bayes classification in R

In this recipe, we shall learn how to implement the Naive Bayes classification algorithm, a supervised learning method, with the help of an example in R.

Recipe Objective: How to implement Naive Bayes classification in R?

The Naive Bayes classification algorithm is a supervised learning algorithm based on Bayes' theorem. Its name comprises two words -
Naive: It assumes that the occurrence of a specific feature is independent of the occurrence of other features.
Bayes: It is based on Bayes' theorem.
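For reference, Bayes' theorem gives the probability of a class given the observed features, and the classifier simply picks the class with the highest posterior probability:

 P(class | features) = P(features | class) * P(class) / P(features)

Under the "naive" independence assumption, P(features | class) factors into the product of the individual feature likelihoods. The steps to implement Naive Bayes classification in R are as follows.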

Step 1: Import required libraries

#e1071 provides the naiveBayes() function
library(e1071)

#gmodels provides CrossTable() for comparing predictions with actual values
library(gmodels)

#dplyr is loaded for general data manipulation (not strictly required in this recipe)
library(dplyr)

Step 2: Load the dataset

We will make use of the iris data frame. iris is a built-in data frame that gives the measurements in centimeters of the variables sepal length and width and petal length and width, respectively, for 50 flowers from each of 3 species of iris. The species are Iris setosa, versicolor, and virginica.

data("iris")

#displays first 6 rows of the dataset
head(iris)

   Sepal.Length Sepal.Width Petal.Length Petal.Width Species
 1          5.1         3.5          1.4         0.2  setosa
 2          4.9         3.0          1.4         0.2  setosa
 3          4.7         3.2          1.3         0.2  setosa
 4          4.6         3.1          1.5         0.2  setosa
 5          5.0         3.6          1.4         0.2  setosa
 6          5.4         3.9          1.7         0.4  setosa
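As a quick check, we can confirm the class balance described above; each species appears exactly 50 times:

table(iris$Species)

     setosa versicolor  virginica 
         50         50         50 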

Step 3: Check the structure of the dataset

str(iris)

 'data.frame':    150 obs. of  5 variables:
  $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
  $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
  $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
  $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...

All four independent variables are numeric, and our dependent (target) variable, Species, is a factor with three levels (the 3 species).

Step 4: Checking the summary

The summary of the dataset gives us the minimum, 1st quartile, median, mean, 3rd quartile, and maximum values of all the numeric columns, along with the level counts of the factor column.

summary(iris)

   Sepal.Length    Sepal.Width     Petal.Length    Petal.Width   
  Min.   :4.300   Min.   :2.000   Min.   :1.000   Min.   :0.100  
  1st Qu.:5.100   1st Qu.:2.800   1st Qu.:1.600   1st Qu.:0.300  
  Median :5.800   Median :3.000   Median :4.350   Median :1.300  
  Mean   :5.843   Mean   :3.057   Mean   :3.758   Mean   :1.199  
  3rd Qu.:6.400   3rd Qu.:3.300   3rd Qu.:5.100   3rd Qu.:1.800  
  Max.   :7.900   Max.   :4.400   Max.   :6.900   Max.   :2.500  
        Species  
  setosa    :50  
  versicolor:50  
  virginica :50  

Step 5: Train-test split

#setting a seed ensures that you get the same split each time you run this process
set.seed(1234)

#We split the data into 2 parts: roughly 90% of the rows will be used as the training set and the remaining ~10% as the testing set.
index = sample(2, nrow(iris), prob = c(0.9, 0.1), replace = TRUE)

#training set
train = iris[index==1,]

#testing set
test = iris[index==2,]
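Note that assigning each row independently gives an approximate 90/10 split; in this run it produced 138 training rows and 12 testing rows. If you need an exact split, a minimal alternative sketch is:

#draw exactly 90% of the row indices for the training set
set.seed(1234)
train_idx = sample(nrow(iris), size = floor(0.9 * nrow(iris)))
train = iris[train_idx,]
test = iris[-train_idx,]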

Step 6: Separate the test labels from the test data

#test_data will be given as input to the model to predict the species
test_data = test[1:4]

#test_label holds the actual species values of the test data
test_label = test[,5]
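Equivalently, the columns can be selected by name rather than by position, which is more robust if the column order ever changes:

#name-based alternative to the positional indexing above
test_data = test[, names(test) != "Species"]
test_label = test$Species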

Step 7: Train the model

model = naiveBayes(Species ~ ., data = train)
model

 Naive Bayes Classifier for Discrete Predictors
 
 Call:
 naiveBayes.default(x = X, y = Y, laplace = laplace)
 
 A-priori probabilities:
 Y
     setosa versicolor  virginica 
  0.3260870  0.3405797  0.3333333 
 
 Conditional probabilities:
             Sepal.Length
 Y                [,1]      [,2]
   setosa     5.000000 0.3496752
   versicolor 5.921277 0.4995373
   virginica  6.571739 0.6414092
 
             Sepal.Width
 Y                [,1]      [,2]
   setosa     3.411111 0.3600645
   versicolor 2.761702 0.3193425
   virginica  2.976087 0.3301076
 
             Petal.Length
 Y                [,1]      [,2]
   setosa     1.462222 0.1812694
   versicolor 4.257447 0.4790138
   virginica  5.523913 0.5363186
 
             Petal.Width
 Y                 [,1]      [,2]
   setosa     0.2422222 0.1076376
   versicolor 1.3234043 0.2034622
   virginica  2.0152174 0.2772623
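In the output above, the A-priori probabilities are simply the class proportions in the training set. In each conditional probabilities table, column [,1] is the mean and column [,2] is the standard deviation of that predictor within each class; e1071 models numeric predictors with a Gaussian distribution. For categorical predictors, naiveBayes() also supports Laplace smoothing via the laplace argument; it has no effect on this all-numeric dataset, but a minimal sketch would be:

#add-one (Laplace) smoothing; only affects categorical predictors
model_smoothed = naiveBayes(Species ~ ., data = train, laplace = 1)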

Step 8: Make predictions

test_result=predict(model,test_data)
test_result

  [1] setosa     setosa     setosa     setosa     setosa     versicolor
  [7] versicolor versicolor virginica  virginica  virginica  virginica 
 Levels: setosa versicolor virginica
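Besides hard class labels, predict() can return the posterior probability of each class, which is useful for inspecting how confident the model is:

#class membership probabilities instead of hard labels
test_prob = predict(model, test_data, type = "raw")
head(test_prob)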

Step 9: Compare the predicted and actual values

CrossTable(x=test_label, y=test_result)

 
    Cell Contents
 |-------------------------|
 |                       N |
 | Chi-square contribution |
 |           N / Row Total |
 |           N / Col Total |
 |         N / Table Total |
 |-------------------------|
 
  
 Total Observations in Table:  12 
 
  
              | test_result 
   test_label |     setosa | versicolor |  virginica |  Row Total | 
 -------------|------------|------------|------------|------------|
       setosa |          5 |          0 |          0 |          5 | 
              |      4.083 |      1.250 |      1.667 |            | 
              |      1.000 |      0.000 |      0.000 |      0.417 | 
              |      1.000 |      0.000 |      0.000 |            | 
              |      0.417 |      0.000 |      0.000 |            | 
 -------------|------------|------------|------------|------------|
   versicolor |          0 |          3 |          0 |          3 | 
              |      1.250 |      6.750 |      1.000 |            | 
              |      0.000 |      1.000 |      0.000 |      0.250 | 
              |      0.000 |      1.000 |      0.000 |            | 
              |      0.000 |      0.250 |      0.000 |            | 
 -------------|------------|------------|------------|------------|
    virginica |          0 |          0 |          4 |          4 | 
              |      1.667 |      1.000 |      5.333 |            | 
              |      0.000 |      0.000 |      1.000 |      0.333 | 
              |      0.000 |      0.000 |      1.000 |            | 
              |      0.000 |      0.000 |      0.333 |            | 
 -------------|------------|------------|------------|------------|
 Column Total |          5 |          3 |          4 |         12 | 
              |      0.417 |      0.250 |      0.333 |            | 
 -------------|------------|------------|------------|------------|
 

We can see from the cross table that our model has correctly predicted the species of every row in the test set: all 5 setosa, 3 versicolor, and 4 virginica rows fall on the diagonal, with no misclassifications. The accuracy of the model on this (small) test set is therefore 100%.
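The accuracy can also be computed directly by comparing the predicted and actual labels:

#proportion of test rows where the prediction matches the actual species
mean(test_result == test_label)

  [1] 1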
