K-Means Clustering with scikit-learn

Recipe Objective - K-Means Clustering with scikit-learn

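The k-means clustering problem is solved using either Lloyd’s or Elkan’s algorithm. The k-means algorithm is very fast (one of the fastest clustering algorithms available), but it can converge to a local minimum instead of the global one, depending on where the initial centroids are placed. That is why it is useful to restart it several times with different initializations and keep the best result.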
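The effect of restarts can be sketched as follows. This is only an illustration under assumptions: the synthetic five-blob dataset, the seed range, and the variable names are arbitrary choices. Each run uses a single random initialization (`n_init=1`), and the restart strategy keeps the run with the lowest inertia, which is the selection rule `n_init` applies when it is greater than 1.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Synthetic data (an assumption for illustration): 5 blobs in 2-D.
X_demo, _ = make_blobs(n_samples=300, centers=5, cluster_std=0.8, random_state=1)

# One k-means run per seed, each from a single random initialization.
inertias = []
for seed in range(10):
    run = KMeans(n_clusters=5, init='random', n_init=1, random_state=seed).fit(X_demo)
    inertias.append(run.inertia_)

# Restarting and keeping the best run guards against a poor local minimum.
best_inertia = min(inertias)
worst_inertia = max(inertias)
print(f"best: {best_inertia:.1f}, worst: {worst_inertia:.1f}")
```

If the best and worst values differ, at least one run stopped in a worse local minimum. If they are equal, every initialization happened to reach the same solution on this data.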

The average complexity is given by O(k n T), where k is the number of clusters, n is the number of samples and T is the number of iterations.
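To see where the k, n and T factors come from, here is a simplified NumPy sketch of Lloyd-style iterations. It is not scikit-learn's implementation: the function name `lloyd_kmeans`, the stopping rule based on the centroid shift, and the toy data are assumptions made for illustration. Each iteration fills an (n, k) table of distances, and the loop runs for T iterations. The number of features also enters the per-distance cost and is treated as a constant here.

```python
import numpy as np

def lloyd_kmeans(X, k, max_iter=300, tol=1e-4, seed=0):
    """Plain Lloyd iterations; returns (centers, labels, n_iter)."""
    rng = np.random.default_rng(seed)
    # Start from k distinct samples picked at random.
    centers = X[rng.choice(len(X), size=k, replace=False)]
    for n_iter in range(1, max_iter + 1):
        # Assignment step: an (n, k) table of squared distances.
        sq_dists = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = sq_dists.argmin(axis=1)
        # Update step: each center moves to the mean of its points
        # (an empty cluster keeps its previous center).
        new_centers = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        shift = np.linalg.norm(new_centers - centers)
        centers = new_centers
        if shift < tol:
            break
    # Final assignment so the labels match the returned centers.
    sq_dists = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return centers, sq_dists.argmin(axis=1), n_iter

# Toy data (an assumption): three Gaussian blobs of 50 points each.
rng = np.random.default_rng(0)
blob_means = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
X_toy = np.vstack([m + 0.5 * rng.standard_normal((50, 2)) for m in blob_means])

centers, labels, n_iter = lloyd_kmeans(X_toy, k=3)
print(centers.shape, labels.shape, n_iter)
```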

K-means is very good at identifying clusters with a roughly spherical shape. One of its drawbacks is that the number of clusters has to be specified in advance.
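The preference for spherical clusters can be checked with a small comparison. This sketch rests on assumptions that are not in the original recipe: the `make_moons` dataset stands in for non-spherical clusters, and the adjusted Rand index (`adjusted_rand_score`) compares the k-means labels with the known generating groups.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs, make_moons
from sklearn.metrics import adjusted_rand_score

# Compact, roughly spherical groups.
X_blobs, y_blobs = make_blobs(n_samples=150, centers=3, cluster_std=0.5, random_state=0)
# Two interleaving half-circles, which are not spherical.
X_moons, y_moons = make_moons(n_samples=200, noise=0.05, random_state=0)

pred_blobs = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X_blobs)
pred_moons = KMeans(n_clusters=2, n_init=10, random_state=0).fit_predict(X_moons)

# Adjusted Rand index: 1.0 means the clusters match the true groups exactly.
ari_blobs = adjusted_rand_score(y_blobs, pred_blobs)
ari_moons = adjusted_rand_score(y_moons, pred_moons)
print(f"blobs ARI: {ari_blobs:.2f}, moons ARI: {ari_moons:.2f}")
```

The blobs should score close to 1.0, while the half-moons score clearly lower because k-means splits the plane with a straight boundary between the two centroids.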

Example:-

Importing libraries and creating dataset:-

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# create dataset
X, y = make_blobs(n_samples=150, n_features=2, centers=3,
                  cluster_std=0.5, shuffle=True, random_state=0)

# plot
plt.scatter(X[:, 0], X[:, 1], c='white', marker='o', edgecolor='black', s=50)
plt.show()

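K-means clustering using sklearn:-

# run k-means 10 times from random initial centroids and keep the best run
km = KMeans(n_clusters=3, init='random', n_init=10, max_iter=300, tol=1e-04, random_state=0)
# fit the model and get the cluster index of every sample
y_km = km.fit_predict(X)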
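Once the model is fitted, its results can be inspected and reused. The sketch below repeats the dataset and model from this example so that it runs on its own. The `new_points` array is hypothetical data made up for illustration: the learned centroids nudged by a small offset.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=150, n_features=2, centers=3,
                  cluster_std=0.5, shuffle=True, random_state=0)
km = KMeans(n_clusters=3, init='random', n_init=10, max_iter=300,
            tol=1e-04, random_state=0)
y_km = km.fit_predict(X)

print(km.cluster_centers_)   # one row per centroid, shape (3, 2)
print(np.bincount(y_km))     # number of samples assigned to each cluster
print(km.n_iter_)            # iterations used by the best run

# predict() assigns new points to the nearest learned centroid.
# Hypothetical new points: the centroids themselves, nudged slightly.
new_points = km.cluster_centers_ + 0.01
new_labels = km.predict(new_points)
print(new_labels)            # each nudged centroid maps back to its own cluster
```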

Plotting clusters:-

# plot the 3 clusters
plt.scatter(X[y_km == 0, 0], X[y_km == 0, 1], s=50, c='lightgreen',
            marker='s', edgecolor='black', label='cluster 1')

plt.scatter(X[y_km == 1, 0], X[y_km == 1, 1], s=50, c='orange',
            marker='o', edgecolor='black', label='cluster 2')

plt.scatter(X[y_km == 2, 0], X[y_km == 2, 1], s=50, c='lightblue',
            marker='v', edgecolor='black', label='cluster 3')

# plot the centroids
plt.scatter(km.cluster_centers_[:, 0], km.cluster_centers_[:, 1], s=250,
            marker='*', c='red', edgecolor='black', label='centroids')

plt.legend(scatterpoints=1)
plt.grid()
plt.show()

Calculating distortion:-

# calculate the distortion (within-cluster sum of squared distances, exposed as inertia_) for 1 to 10 clusters
distortions = []
for i in range(1, 11):
    km = KMeans(n_clusters=i, init='random', n_init=10, max_iter=300, tol=1e-04, random_state=0)
    km.fit(X)
    distortions.append(km.inertia_)

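# plot distortion against the number of clusters; the bend (elbow) in the curve suggests a reasonable number of clusters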
plt.plot(range(1, 11), distortions, marker='o')
plt.xlabel('Number of clusters')
plt.ylabel('Distortion')
plt.show()
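To make the meaning of the distortion concrete, the following sketch recomputes it by hand for the three-cluster model and compares it with `inertia_`. It repeats the dataset and model so that it runs on its own. The names `assigned_centers` and `manual_distortion` are introduced here for illustration.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, y = make_blobs(n_samples=150, n_features=2, centers=3,
                  cluster_std=0.5, shuffle=True, random_state=0)
km = KMeans(n_clusters=3, init='random', n_init=10, max_iter=300,
            tol=1e-04, random_state=0).fit(X)

# Centroid assigned to each sample, shape (150, 2).
assigned_centers = km.cluster_centers_[km.labels_]
# Sum of squared Euclidean distances from each sample to its centroid.
manual_distortion = ((X - assigned_centers) ** 2).sum()

print(manual_distortion, km.inertia_)
```

The two printed values should agree up to floating-point rounding.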
