Fine-Tuning Pretrained Models with PyTorch Transfer Learning

In this tutorial, you’ll learn how to use transfer learning in PyTorch by adapting a pretrained convolutional neural network (ResNet18) to classify images of cats and dogs from the CIFAR-10 dataset. It is aimed at beginners who want to apply deep learning to real image data.

What is Transfer Learning?

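Transfer learning is a technique where you take a model pretrained on a large dataset (such as ImageNet) and adapt it to your own, usually much smaller, dataset. Typically you replace the model's final classification layer with a new one for your task and train that layer, optionally along with some of the pretrained layers, on your data.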

It helps when:

  • You have limited labeled data.

  • You want to save time and computing power.

  • You want to leverage powerful pretrained features.

Real-World Examples

  • Classifying MRI scans (Tumor vs No Tumor)

  • Detecting whether a plant is healthy or not

  • Recognizing dog breeds with limited samples

  • Identifying fake vs real products

Why Use PyTorch for Transfer Learning?

PyTorch is a flexible, intuitive deep learning framework. It builds its computation graph dynamically as your Python code runs, which makes models easy to inspect and debug and suits both beginners and researchers.
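"Dynamic computation graph" means PyTorch records operations as your Python code executes, so the graph can differ on every call, and ordinary loops, conditionals, and print() calls work inside a model. The following is a toy sketch. `TinyNet` and `n_repeats` are hypothetical names used only for illustration.

```python
import torch
from torch import nn

class TinyNet(nn.Module):
    """Hypothetical toy model: its graph is recorded on the fly as forward() runs."""
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x, n_repeats):
        # Plain Python control flow decides how many times the layer is applied
        for _ in range(n_repeats):
            x = torch.relu(self.layer(x))
        return x.sum()

net = TinyNet()
x = torch.randn(2, 4)
loss = net(x, n_repeats=3)   # the graph for this call contains three applications of self.layer
loss.backward()              # gradients flow back through whatever graph was just recorded
print(net.layer.weight.grad.shape)  # torch.Size([4, 4])
```

Because the loop length is an ordinary argument, calling `net(x, n_repeats=1)` builds a different, shorter graph. `backward()` always follows whichever graph the most recent forward pass recorded.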

Top reasons to use PyTorch:

  1. Pythonic & user-friendly

  2. Strong community support

  3. Built-in pretrained models (torchvision.models)

  4. Great for debugging and experimentation

Step-by-Step Guide to Transfer Learning with ResNet18:


1. Set Up the Environment:

Install the dependencies:

pip install torch torchvision matplotlib

Then import the libraries used throughout the tutorial:

import torch
from torch import nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset, random_split
import numpy as np
import matplotlib.pyplot as plt

2. Prepare the Dataset:

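# Resize CIFAR-10's 32x32 images to the 224x224 input size ResNet18 was trained on,
# convert them to tensors, and normalize them with the ImageNet mean and standard deviation
# that the pretrained weights expect
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])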

# Load CIFAR-10 dataset
full_dataset = datasets.CIFAR10(root='data', train=True, download=True, transform=transform)

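# Only use cats (label 3) and dogs (label 5).
# Reading labels from full_dataset.targets avoids loading and resizing all 50,000 images just to filter them.
cat_dog_indices = [i for i, label in enumerate(full_dataset.targets) if label in (3, 5)]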
cat_dog_subset = Subset(full_dataset, cat_dog_indices)

# Relabel: cat=0, dog=1
class BinaryCatDogDataset(torch.utils.data.Dataset):
    def __init__(self, subset):
        self.subset = subset

    def __getitem__(self, idx):
        img, label = self.subset[idx]
        label = 0 if label == 3 else 1
        return img, label

    def __len__(self):
        return len(self.subset)

binary_dataset = BinaryCatDogDataset(cat_dog_subset)

# Split into train and validation sets
train_size = int(0.8 * len(binary_dataset))
val_size = len(binary_dataset) - train_size
train_dataset, val_dataset = random_split(binary_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
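Downloading CIFAR-10 and resizing every image takes a while. The following sketch exercises the same filtering, relabeling, and splitting steps on a tiny synthetic stand-in. `FakeCIFAR` is a hypothetical class that only mimics CIFAR10's `targets` attribute and its `(image, label)` items. The sketch also passes a seeded `torch.Generator` to `random_split`, which the tutorial's code does not do, so that the split is reproducible.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset, random_split

class FakeCIFAR(Dataset):
    """Hypothetical stand-in for datasets.CIFAR10: random 3x32x32 images and a `targets` list."""
    def __init__(self, targets):
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return torch.rand(3, 32, 32), self.targets[idx]

class BinaryCatDogDataset(Dataset):
    """Same relabeling as the tutorial: CIFAR-10 cat (3) -> 0, dog (5) -> 1."""
    def __init__(self, subset):
        self.subset = subset

    def __getitem__(self, idx):
        img, label = self.subset[idx]
        return img, 0 if label == 3 else 1

    def __len__(self):
        return len(self.subset)

# 100 fake samples cycling through all ten CIFAR-10 classes
fake = FakeCIFAR([i % 10 for i in range(100)])
indices = [i for i, label in enumerate(fake.targets) if label in (3, 5)]
binary = BinaryCatDogDataset(Subset(fake, indices))
print(len(binary))  # 20: ten cats and ten dogs

# A seeded generator makes random_split return the same split on every run
train_size = int(0.8 * len(binary))
val_size = len(binary) - train_size
train_a, val_a = random_split(binary, [train_size, val_size], generator=torch.Generator().manual_seed(42))
train_b, val_b = random_split(binary, [train_size, val_size], generator=torch.Generator().manual_seed(42))
print(train_a.indices == train_b.indices)  # True

# Inspect one batch: images keep their shape, and labels arrive as a 0/1 integer tensor
loader = DataLoader(train_a, batch_size=8, shuffle=True)
imgs, labels = next(iter(loader))
print(imgs.shape, labels.tolist())
```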

3. Load Pretrained ResNet18 and Fine-Tune:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

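# Load ResNet18 with ImageNet-pretrained weights
# (the older `pretrained=True` argument is deprecated since torchvision 0.13 in favor of `weights`)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)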

# Freeze every pretrained layer; only the new classifier head defined below will be trained
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier head for binary classification
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 1),
    nn.Sigmoid()
)

model = model.to(device)

4. Train the Model:

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)

def train(model, loader):
    model.train()
    running_loss = 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device).float().unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss / len(loader)

def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
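            outputs = model(imgs)  # shape (batch, 1): predicted probability of "dog"
            preds = (outputs > 0.5).squeeze(1).long()  # squeeze(1) always yields a 1-D tensor, even for a batch of one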
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total
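The head ends in `nn.Sigmoid()` and the loss is `nn.BCELoss()`, which works as written. PyTorch also offers `nn.BCEWithLogitsLoss`, which folds the sigmoid into the loss and is documented as more numerically stable. The sketch below, on random stand-in logits, checks that the two formulations give the same loss. It also shows that thresholding probabilities at 0.5 is the same as thresholding raw logits at 0. If you switch, drop the final `nn.Sigmoid()` from the head and threshold the raw outputs at 0 (or apply `torch.sigmoid` first) in `evaluate()` and when predicting.

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8, 1)                     # raw scores, as a head without nn.Sigmoid() would output
targets = torch.randint(0, 2, (8, 1)).float()  # 0 = cat, 1 = dog, shaped like the labels in train()

loss_two_step = nn.BCELoss()(torch.sigmoid(logits), targets)  # what the tutorial computes
loss_fused = nn.BCEWithLogitsLoss()(logits, targets)          # sigmoid folded into the loss

# The two losses agree up to floating-point rounding
print(torch.allclose(loss_two_step, loss_fused, atol=1e-6))  # True

# With logits, the 0.5 probability threshold becomes a threshold at 0
preds_from_logits = (logits > 0).squeeze(1).long()
preds_from_probs = (torch.sigmoid(logits) > 0.5).squeeze(1).long()
print(torch.equal(preds_from_logits, preds_from_probs))  # True
```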

5. Run the Training Loop:

epochs = 5
for epoch in range(epochs):
    loss = train(model, train_loader)
    acc = evaluate(model, val_loader)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}, Val Accuracy: {acc:.4f}")

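Example output (exact values will vary from run to run, because the train/validation split and the new head's initial weights are random):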

Epoch 1, Loss: 0.4480, Val Accuracy: 0.8220
Epoch 2, Loss: 0.3879, Val Accuracy: 0.8310
Epoch 3, Loss: 0.3644, Val Accuracy: 0.8295
Epoch 4, Loss: 0.3624, Val Accuracy: 0.8325
Epoch 5, Loss: 0.3489, Val Accuracy: 0.8300

6. Predict on New Data:

sample_img, _ = val_dataset[0]
model.eval()
with torch.no_grad():
    pred = model(sample_img.unsqueeze(0).to(device))
    pred_class = "Dog" if pred.item() > 0.5 else "Cat"
    print(f"Predicted Class: {pred_class}")

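# Convert the (C, H, W) tensor to (H, W, C) for matplotlib and rescale pixel values to [0, 1] for display
img = sample_img.permute(1, 2, 0).numpy()
img = (img - img.min()) / (img.max() - img.min())
plt.imshow(img)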
plt.title(f"Prediction: {pred_class}")
plt.axis('off')
plt.show()

