PyTorch DataLoader Explained - How to make a Custom Dataset
Let's play with Fourier transforms on a Custom Dataset
PyTorch Data Loading
Let me walk you through exactly how PyTorch handles data loading and show you a practical example.
How PyTorch Data Loading Works
It all starts with the Dataset class. This is where you wrap your raw data and implement two key methods:
- __len__() tells PyTorch how many samples you have
- __getitem__() defines how to retrieve a single sample by index
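As a minimal sketch (not the notebook's actual code), a custom Dataset only needs those two methods; the MyDataset name and the in-memory random tensor are just placeholders:

```python
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data                     # e.g. a tensor of shape (num_samples, num_features)

    def __len__(self):
        # Tells PyTorch (and the Sampler) how many samples there are.
        return len(self.data)

    def __getitem__(self, idx):
        # How to fetch a single sample by index.
        return self.data[idx]

dataset = MyDataset(torch.randn(100, 8))     # 100 fake samples with 8 features each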
The Sampler uses that length information to generate a list of indices representing your samples. You don't usually need to worry about this - the DataLoader handles it by default. The key thing is: the sampler generates indices up to the length of your dataset, which is why __len__() matters.
Those indices get passed to your dataset's __getitem__() method to retrieve individual samples - as many per batch as the batch size you specified in the DataLoader.
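If you want to see those indices for yourself, you can build the default-style samplers by hand; this is purely illustrative - the DataLoader does the equivalent for you, using the dataset sketched above:

```python
from torch.utils.data import SequentialSampler, BatchSampler

sampler = SequentialSampler(dataset)                          # yields 0, 1, ..., len(dataset) - 1
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)

for indices in batch_sampler:
    print(indices)    # e.g. [0, 1, 2, 3] - each index is passed to dataset.__getitem__()
    break
```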
The DataLoader then calls the collate function, which by default just stacks those individual samples into batch tensors. You can write custom collate functions if you need different behavior - maybe padding variable length sequences or filtering out bad samples.
That collate function creates the final batch tensor that gets fed to your training loop.
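For example, a custom collate function for variable-length sequences might pad everything to the longest sample in the batch. This is a hedged sketch assuming each sample is a (sequence, target) pair, which is not necessarily how the notebook's data is structured:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch):
    # `batch` is a list of (sequence, target) samples with varying sequence lengths.
    sequences, targets = zip(*batch)
    padded = pad_sequence(list(sequences), batch_first=True)  # pad all sequences to the longest one
    return padded, torch.stack(targets)

# loader = DataLoader(variable_length_dataset, batch_size=32, collate_fn=pad_collate)
```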
So the flow is: raw data → Dataset → Sampler → DataLoader → Collate → Batch → Your model.
You need to define a Dataset and a DataLoader, but you can rely on the default Sampler and collate function.
Walking Through a Real Example
In this notebook, I demonstrate this with two different datasets for time series forecasting.
Since I'm a time series research engineer, both examples use patterns of electrical data as the time series.
Basic Dataset Implementation
First, we create a BasicTimeSeriesDataset that inherits from PyTorch's Dataset class. It implements those two dunder functions I mentioned:
- __len__() returns the number of samples available
- __getitem__() returns one sample at a given index - in our case, 96 timesteps as input and 24 timesteps as the target to forecast
We feed this into a DataLoader with a batch size, and suddenly we have 309 batches ready for training. The batching lets us saturate our GPUs and do operations in parallel.
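A sliding-window implementation along those lines might look like the sketch below. The class name and the 96/24 split come from the notebook, but the windowing logic and the stand-in sine-wave series are my assumptions:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class BasicTimeSeriesDataset(Dataset):
    """Sliding-window dataset: each sample is 96 input steps plus the next 24 as the target."""
    def __init__(self, series, input_len=96, target_len=24):
        self.series = torch.as_tensor(series, dtype=torch.float32)
        self.input_len = input_len
        self.target_len = target_len

    def __len__(self):
        # Number of full (input, target) windows that fit in the series.
        return len(self.series) - self.input_len - self.target_len + 1

    def __getitem__(self, idx):
        x = self.series[idx : idx + self.input_len]
        y = self.series[idx + self.input_len : idx + self.input_len + self.target_len]
        return x, y

series = torch.sin(torch.linspace(0, 100, 10_000))   # stand-in for the electrical data
loader = DataLoader(BasicTimeSeriesDataset(series), batch_size=32, shuffle=True)
```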
Custom Dataset for Inference-Time Denoising
Then we get more advanced with an InferenceTimeSeriesDataset. This one does something interesting - it applies Fourier transforms to denoise the data at inference time. The idea is that by removing noise when we want to forecast, we're giving the model cleaner, more prototypical patterns to work with.
This might be nonsense, I just wanted to try it.
The dataset includes a fourier_denoise() function that uses the Fast Fourier Transform to identify the strongest frequencies, zeros out the weaker noisy ones, then transforms the signal back into a time series. It's the same input structure to our model, just with some feature engineering to make it less noisy.
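The notebook's exact implementation isn't reproduced here, but a fourier_denoise() along those lines might keep only the k strongest frequency components and zero the rest. The keep parameter and the top-k thresholding below are my assumptions, and the sketch assumes a 1-D signal:

```python
import torch

def fourier_denoise(x, keep=10):
    """Sketch: keep the `keep` strongest frequencies of a 1-D signal, zero out the rest."""
    freq = torch.fft.rfft(x)                               # to the frequency domain
    magnitudes = freq.abs()
    threshold = torch.topk(magnitudes, keep).values.min()  # cutoff = weakest of the top-k frequencies
    freq[magnitudes < threshold] = 0                       # zero out the weaker, noisy frequencies
    return torch.fft.irfft(freq, n=x.shape[-1])            # back to the time domain
```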
The Beautiful Part
What's great about PyTorch is how modular everything is. Once you have your custom dataset, you just point your DataLoader to it. The training loop stays exactly the same - zero gradients, run inference, calculate loss, backward propagation, step forward. Whether you're using the basic dataset or the fancy denoising one, the rest of your code doesn't change.
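As a sketch, that loop looks like this in code. The model, criterion, and optimizer here are stand-ins I've assumed; loader is the DataLoader built above:

```python
import torch

model = torch.nn.Linear(96, 24)                      # stand-in forecasting model: 96 steps in, 24 out
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for inputs, targets in loader:
    optimizer.zero_grad()                            # zero gradients
    outputs = model(inputs)                          # run inference (forward pass)
    loss = criterion(outputs, targets)               # calculate loss
    loss.backward()                                  # backward propagation
    optimizer.step()                                 # step forward
```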
At inference time, you create your denoising dataset, set denoise=True, put the model in eval mode with model.eval(), wrap everything in torch.no_grad() to avoid storing gradients, and make your predictions on hopefully cleaner inputs.
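Put together, inference might look roughly like this. The InferenceTimeSeriesDataset constructor arguments are illustrative; only the denoise=True flag comes from the description above:

```python
import torch
from torch.utils.data import DataLoader

inference_ds = InferenceTimeSeriesDataset(series, denoise=True)   # constructor args are assumed
inference_loader = DataLoader(inference_ds, batch_size=32, shuffle=False)

model.eval()                          # switch the model to evaluation mode
with torch.no_grad():                 # don't track or store gradients at inference time
    for inputs, targets in inference_loader:
        preds = model(inputs)         # forecasts on the (hopefully) cleaner, denoised inputs
```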
The dataset abstraction means you can experiment with different data preprocessing approaches without rewriting your entire training pipeline. That's the power of PyTorch's design - each component handles one thing well, and they all work together seamlessly.