In this tutorial, you’ll learn how to use transfer learning in PyTorch by fine-tuning a pretrained CNN model (ResNet18) to perform image classification. This is ideal for beginners exploring deep learning with real-world image data.
What is Transfer Learning?
Transfer learning is a technique where you take a model pretrained on a large dataset (like ImageNet) and fine-tune it on your own smaller dataset.
It helps when:
- You have limited labeled data.
- You want to save time and computing power.
- You want to leverage powerful pretrained features.
Real-World Examples
- Classifying MRI scans (Tumor vs No Tumor)
- Detecting whether a plant is healthy or not
- Recognizing dog breeds with limited samples
- Identifying fake vs real products
Why Use PyTorch for Transfer Learning?
PyTorch is a flexible deep learning framework that builds its computation graph dynamically as your code runs. This makes models easy to write, inspect, and debug, for beginners and researchers alike.
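The example below shows what a dynamic computation graph means in practice. Ordinary Python `if` statements decide which operations run, and autograd differentiates whatever actually executed. The function `forward` is a hypothetical toy used only for illustration.

```python
import torch


def forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    y = x * w
    # Plain Python control flow: the branch taken depends on the data,
    # and autograd records only the operations that actually ran
    if y.sum() > 0:
        y = y * 2
    return y.sum()


w = torch.tensor(3.0, requires_grad=True)

loss_pos = forward(torch.tensor([1.0, 2.0]), w)    # doubling branch taken
(grad_pos,) = torch.autograd.grad(loss_pos, w)

loss_neg = forward(torch.tensor([-1.0, -2.0]), w)  # doubling branch skipped
(grad_neg,) = torch.autograd.grad(loss_neg, w)

print(loss_pos.item(), grad_pos.item())  # 18.0 6.0
print(loss_neg.item(), grad_neg.item())  # -9.0 -3.0
```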
Top reasons to use PyTorch:
- Pythonic & user-friendly
- Strong community support
- Built-in pretrained models (`torchvision.models`)
- Great for debugging and experimentation
Step-by-Step Guide to Transfer Learning with ResNet18:
1. Install Dependencies:
```bash
pip install torch torchvision matplotlib
```
2. Import Libraries:
```python
import torch
from torch import nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset, random_split
import numpy as np
import matplotlib.pyplot as plt
```
3. Prepare the Dataset:
```python
# Resize CIFAR-10's 32x32 images to ResNet's 224x224 input and normalize
# with the ImageNet mean/std that the pretrained weights expect
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load the CIFAR-10 training set
full_dataset = datasets.CIFAR10(root='data', train=True, download=True, transform=transform)

# Only use cats (label 3) and dogs (label 5).
# Reading labels from .targets avoids loading and transforming every image.
cat_dog_indices = [i for i, label in enumerate(full_dataset.targets) if label in (3, 5)]
cat_dog_subset = Subset(full_dataset, cat_dog_indices)

# Relabel: cat=0, dog=1
class BinaryCatDogDataset(torch.utils.data.Dataset):
    def __init__(self, subset):
        self.subset = subset

    def __getitem__(self, idx):
        img, label = self.subset[idx]
        label = 0 if label == 3 else 1
        return img, label

    def __len__(self):
        return len(self.subset)

binary_dataset = BinaryCatDogDataset(cat_dog_subset)

# Split into 80% training and 20% validation
train_size = int(0.8 * len(binary_dataset))
val_size = len(binary_dataset) - train_size
train_dataset, val_dataset = random_split(binary_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
```
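Before training, it helps to check that the filter-and-relabel logic produces what you expect. The sketch below runs the same logic on a tiny synthetic stand-in for CIFAR-10, so nothing has to be downloaded. The stand-in is an assumption made for illustration: 20 random 32x32 "images" whose labels cycle through 0 to 9.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset

# Hypothetical stand-in for CIFAR-10: random images, labels 0..9 repeating
images = torch.randn(20, 3, 32, 32)
targets = [i % 10 for i in range(20)]
fake_cifar = TensorDataset(images, torch.tensor(targets))

# Same filtering idea: select indices from the label list only
cat_dog_indices = [i for i, label in enumerate(targets) if label in (3, 5)]
subset = Subset(fake_cifar, cat_dog_indices)


class BinaryCatDogDataset(Dataset):
    def __init__(self, subset):
        self.subset = subset

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        img, label = self.subset[idx]
        # int() because TensorDataset returns labels as 0-d tensors
        return img, 0 if int(label) == 3 else 1


binary = BinaryCatDogDataset(subset)
batch_images, batch_labels = next(iter(DataLoader(binary, batch_size=4)))

print(batch_images.shape)     # torch.Size([4, 3, 32, 32])
print(batch_labels.tolist())  # [0, 1, 0, 1]
```

With the real dataset, the same check (one batch from `train_loader`) should show image tensors of shape `[32, 3, 224, 224]` and labels that are only 0 or 1.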
4. Load Pretrained ResNet18 and Fine-Tune:
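```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ResNet18 with ImageNet weights (torchvision >= 0.13;
# older versions use models.resnet18(pretrained=True))
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Freeze all pretrained layers so their weights are not updated
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier head for binary classification.
# Newly created layers have requires_grad=True, so only the head is trained.
# The final Sigmoid outputs a probability, which pairs with nn.BCELoss.
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 1),
    nn.Sigmoid(),
)

model = model.to(device)
```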