In this tutorial, you’ll learn how to use transfer learning in PyTorch by fine-tuning a pretrained CNN model (ResNet18) to perform image classification. This is ideal for beginners exploring deep learning with real-world image data.
What is Transfer Learning?
Transfer learning is a technique where you take a model pretrained on a large dataset (like ImageNet) and fine-tune it on your own smaller dataset.
It helps when:
- You have limited labeled data.
- You want to save time and computing power.
- You want to leverage powerful pretrained features.
Real-World Examples
- Classifying MRI scans (Tumor vs No Tumor)
- Detecting whether a plant is healthy or not
- Recognizing dog breeds with limited samples
- Identifying fake vs real products
Why Use PyTorch for Transfer Learning?
PyTorch is a flexible, intuitive deep learning framework with dynamic computation graphs, making it perfect for both beginners and researchers.
Top reasons to use PyTorch:
- Pythonic & user-friendly
- Strong community support
- Built-in pretrained models (torchvision.models)
- Great for debugging and experimentation
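To make "dynamic computation graphs" concrete, here is a minimal sketch. The function name `f` and the input values are made up for illustration. The graph is built while ordinary Python code runs, so an `if` statement can decide which operations autograd records.

```python
import torch

def f(x):
    # Plain Python control flow picks the branch at run time;
    # autograd records only the operations that actually ran.
    if x.sum() > 0:
        return (x ** 2).sum()
    return (x ** 3).sum()

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = f(x)
y.backward()

# The first branch ran, so the gradient is 2 * x
grad = x.grad
print(grad)
```

Because the graph is rebuilt on every call, you can step through `f` with a regular debugger or add `print` statements, which is a large part of why PyTorch is pleasant to experiment with.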
Step-by-Step Guide to Transfer Learning with ResNet18:
1. Install Dependencies:
pip install torch torchvision matplotlib
2. Import Libraries:
import torch
from torch import nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset, random_split
import numpy as np
import matplotlib.pyplot as plt
3. Prepare the Dataset:
# Resize to 224x224 (the input size ResNet18 was trained on) and convert to a tensor scaled to [0, 1]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# Load CIFAR-10 dataset
full_dataset = datasets.CIFAR10(root='data', train=True, download=True, transform=transform)
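# Only use cats (label 3) and dogs (label 5).
# Reading full_dataset.targets avoids loading and resizing every image just to check its label.
cat_dog_indices = [i for i, label in enumerate(full_dataset.targets) if label in (3, 5)]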
cat_dog_subset = Subset(full_dataset, cat_dog_indices)
# Relabel: cat=0, dog=1
class BinaryCatDogDataset(torch.utils.data.Dataset):
    def __init__(self, subset):
        self.subset = subset

    def __getitem__(self, idx):
        img, label = self.subset[idx]
        label = 0 if label == 3 else 1
        return img, label

    def __len__(self):
        return len(self.subset)
binary_dataset = BinaryCatDogDataset(cat_dog_subset)
# Split into train and validation sets
train_size = int(0.8 * len(binary_dataset))
val_size = len(binary_dataset) - train_size
train_dataset, val_dataset = random_split(binary_dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
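One thing worth knowing about `random_split`: without a seed, the train/validation split changes on every run, which makes results harder to compare. Here is a minimal sketch of a reproducible split. `toy_dataset` is a small random dataset used only as a stand-in for `binary_dataset`, and the seed value 42 is arbitrary.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in for binary_dataset: 100 tiny random "images" with 0/1 labels
images = torch.rand(100, 3, 8, 8)
labels = torch.randint(0, 2, (100,))
toy_dataset = TensorDataset(images, labels)

train_size = int(0.8 * len(toy_dataset))
val_size = len(toy_dataset) - train_size

# A seeded generator makes the split identical on every run
generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(toy_dataset, [train_size, val_size], generator=generator)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)

# Pull one batch to confirm the shapes the model will receive
batch_images, batch_labels = next(iter(train_loader))
print(len(train_ds), len(val_ds), batch_images.shape, batch_labels.shape)
```

Checking one batch like this is a cheap way to catch shape or label problems before training starts.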
4. Load Pretrained ResNet18 and Fine-Tune:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
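# Load ResNet18 with ImageNet-pretrained weights
# (the weights argument needs torchvision 0.13 or newer; older releases use pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)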
# Freeze all layers
for param in model.parameters():
    param.requires_grad = False
# Replace the classifier head for binary classification
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 1),
    nn.Sigmoid()
)
model = model.to(device)
