Author: Sasank Chilamkurthy

In this tutorial, you will learn how to train a convolutional neural network for image classification using transfer learning. Read more about transfer learning at the CS231n notes.

Why Transfer Learning?

In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to:
  1. Pretrain a ConvNet on a very large dataset (e.g., ImageNet with 1.2 million images and 1000 categories)
  2. Use the ConvNet either as an initialization or a fixed feature extractor for the task of interest

Two Major Transfer Learning Scenarios

1. Finetuning the ConvNet

Instead of random initialization, we initialize the network with a pretrained network (e.g., one trained on ImageNet). The rest of training proceeds as usual.
import torch.nn as nn
import torch.optim as optim
from torchvision import models

num_classes = 2  # number of classes in the new task (e.g., ants vs. bees)

# Load a model pretrained on ImageNet
model_ft = models.resnet18(pretrained=True)

# Replace the final fully connected layer for the new task
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, num_classes)

# Train all parameters
optimizer = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
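Note that on torchvision 0.13 and later, pretrained=True is deprecated in favor of the weights argument; a minimal equivalent load, assuming a recent torchvision:

# torchvision >= 0.13: select the pretrained weights explicitly
model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)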

2. ConvNet as Fixed Feature Extractor

Freeze the weights of all layers except the final fully connected layer. Replace that last layer with a new, randomly initialized one, and train only this layer.
# Load a model pretrained on ImageNet
model_conv = models.resnet18(pretrained=True)

# Freeze all parameters so no gradients are computed for them
for param in model_conv.parameters():
    param.requires_grad = False

# Replace the final layer; parameters of newly constructed
# modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, num_classes)

# Only optimize the final layer's parameters
optimizer = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
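As a quick sanity check (a sketch, not part of the original recipe), you can list which parameters will actually receive gradients; only the new layer should appear:

# Only the freshly created fc layer should be trainable
trainable = [name for name, p in model_conv.named_parameters() if p.requires_grad]
print(trainable)  # expected: ['fc.weight', 'fc.bias']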

Example: Ants vs Bees Classification

Dataset

A small dataset with about 120 training images each for ants and bees, and 75 validation images per class. This is usually far too little data to train a network from scratch, but with transfer learning the model can still generalize reasonably well.
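The tutorial distributes the data as a zip archive; a minimal download-and-extract sketch (assuming the URL is still live and extracting into a data/ directory):

import urllib.request
import zipfile

# Fetch and unpack the ants/bees dataset used by the PyTorch tutorial
url = 'https://download.pytorch.org/tutorial/hymenoptera_data.zip'
urllib.request.urlretrieve(url, 'hymenoptera_data.zip')
with zipfile.ZipFile('hymenoptera_data.zip') as zf:
    zf.extractall('data')  # creates data/hymenoptera_data/{train,val}/...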

Data Augmentation

from torchvision import transforms

# Training data is augmented; validation data is only resized and center-cropped
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Normalize with the ImageNet channel means and standard deviations
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
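The training function below assumes dataloaders, dataset_sizes, and device are already defined. A minimal sketch, assuming the directory layout from the download step above:

import os
import torch
from torchvision import datasets

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')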

Training Function

import copy

import torch

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()   # enable dropout, batch-norm updates
            else:
                model.eval()    # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            # Track the best model seen on the validation set
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    # Restore the weights of the best-performing model
    model.load_state_dict(best_model_wts)
    return model
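Putting the pieces together: a sketch of a full finetuning run with the optimizer from the finetuning setup above, cross-entropy loss, and a step LR schedule (the step size and decay factor here are assumed values, decaying the learning rate by 10x every 7 epochs):

import torch.nn as nn
from torch.optim import lr_scheduler

criterion = nn.CrossEntropyLoss()

# Assumed schedule: decay the LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

model_ft = model_ft.to(device)
model_ft = train_model(model_ft, criterion, optimizer, exp_lr_scheduler,
                       num_epochs=25)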

Results Comparison

Approach                | Training Time | Best Validation Accuracy
------------------------|---------------|-------------------------
Finetuning all layers   | ~37 seconds   | 92.8%
Fixed feature extractor | ~23 seconds   | 95.4%
The fixed feature extractor approach trains faster, since gradients are not computed for most of the network, and it also reaches higher accuracy on this small dataset.

When to Use Each Approach

Scenario                                      | Recommended Approach
----------------------------------------------|----------------------------------------
Small dataset, similar to pretrained data     | Fixed feature extractor
Small dataset, different from pretrained data | Finetune with a small learning rate
Large dataset, similar to pretrained data     | Finetune all layers
Large dataset, different from pretrained data | Train from scratch or finetune heavily
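For the "finetune with a small learning rate" case, a common pattern (a sketch, not from the original tutorial) is to use two parameter groups, giving the pretrained backbone a smaller learning rate than the freshly initialized head:

# Assumed learning rates: small for the pretrained backbone,
# larger for the new fc head
backbone_params = [p for n, p in model_ft.named_parameters()
                   if not n.startswith('fc')]
optimizer = optim.SGD([
    {'params': backbone_params, 'lr': 1e-4},
    {'params': model_ft.fc.parameters(), 'lr': 1e-3},
], momentum=0.9)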

References

- CS231n notes on transfer learning: https://cs231n.github.io/transfer-learning/