How to split a dataset using PyTorch

This recipe helps you split a dataset using PyTorch.

Recipe Objective

How to split a dataset using PyTorch?

This is done with the "random_split" function from torch.utils.data. It randomly splits a dataset into two or more non-overlapping subsets of given lengths, and it is commonly used to create train and test datasets.
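
As a minimal sketch of the call, the example below splits a toy TensorDataset. The tensor values and the 7/3 split are illustrative assumptions, not part of the recipe.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset: 10 samples with 3 features each, plus integer labels (illustrative values).
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# The lengths must sum to len(dataset); each returned object is a torch.utils.data.Subset.
train_subset, test_subset = random_split(dataset, [7, 3])

print(len(train_subset), len(test_subset))
```

Each subset keeps a reference to the original dataset plus a list of indices, so no data is copied by the split.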

Step 1 - Import libraries

import pprint as pp
from sklearn import datasets
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import random_split

Step 2 - Take sample data

samples = 2000
# The two centers are 2-D points, so each sample has 2 features
X_data, Y_data = datasets.make_blobs(n_samples=samples, n_features=2, centers=[(0,5),(4,0)], random_state=0)

Step 3 - Create Dataset Class

class CreateDataset(Dataset):
    def __init__(self, x, y):
        # Store the arrays passed in, not the module-level X_data / Y_data
        self.x = x
        self.y = y
    def __getitem__(self, index):
        # Return one sample as a dict of tensors
        sample = {'feature': torch.tensor([self.x[index]], dtype=torch.float32),
                  'label': torch.tensor([self.y[index]], dtype=torch.long)}
        return sample
    def __len__(self):
        # Number of samples in the dataset
        return len(self.x)
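
A map-style Dataset only needs `__getitem__` and `__len__`. As a hedged variant of the class above, the sketch below converts the arrays to tensors once in `__init__`, so that indexing is only a lookup. The class name `ArrayDataset` and the small arrays are hypothetical, chosen for illustration only.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class ArrayDataset(Dataset):  # hypothetical name, not from the recipe
    def __init__(self, x, y):
        # Convert once up front so __getitem__ only has to index.
        self.x = torch.as_tensor(x, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.long)

    def __getitem__(self, index):
        # One sample as a dict: a feature vector and a scalar label.
        return {'feature': self.x[index], 'label': self.y[index]}

    def __len__(self):
        return len(self.x)


# Illustrative data: 5 samples, 2 features each.
x = np.zeros((5, 2))
y = np.array([0, 1, 0, 1, 1])

ds = ArrayDataset(x, y)
sample = ds[3]  # calls __getitem__(3)
print(len(ds), sample['feature'].shape, sample['label'].item())
```

Note that this variant returns a feature tensor of shape (n_features,) instead of wrapping each sample in an extra leading dimension.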

Step 4 - Create the dataset and check its length

torch_dataset = CreateDataset(X_data, Y_data)
print("length of the dataset is:", len(torch_dataset))

length of the dataset is: 2000

Step 5 - Split the dataset

train_data, test_data = random_split(torch_dataset, [1400, 600])
print("The length of train data is:", len(train_data))
print("The length of test data is:", len(test_data))

The length of train data is: 1400
The length of test data is: 600
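
The split above is different on every run. One way to make it repeatable, sketched here under assumptions, is to pass a seeded `torch.Generator` to `random_split`, then wrap each subset in a `DataLoader` for batching. The stand-in `TensorDataset`, the seed 42, and the batch size 64 are illustrative choices, not part of the recipe.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in dataset with the same size as in the recipe (2000 samples, 2 features).
dataset = TensorDataset(torch.randn(2000, 2), torch.randint(0, 2, (2000,)))

# A seeded generator makes the split repeatable across runs.
generator = torch.Generator().manual_seed(42)
train_data, test_data = random_split(dataset, [1400, 600], generator=generator)

# Shuffle only the training batches; the test set is read in a fixed order.
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

# Pull one batch to check the shapes.
features, labels = next(iter(train_loader))
print(features.shape, labels.shape)
```

Reusing the same seed with the same dataset length and split sizes yields the same train and test indices, which helps when comparing models across runs.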
