Author: Sasank Chilamkurthy
In this tutorial, you will learn how to train a convolutional neural network for image classification using transfer learning. You can read more about transfer learning in the CS231n notes.
Why Transfer Learning?
In practice, very few people train an entire Convolutional Network from scratch (with random initialization), because it is relatively rare to have a dataset of sufficient size. Instead, it is common to:
- Pretrain a ConvNet on a very large dataset (e.g., ImageNet with 1.2 million images and 1000 categories)
- Use the ConvNet either as an initialization or a fixed feature extractor for the task of interest
Two Major Transfer Learning Scenarios
1. Finetuning the ConvNet
Instead of random initialization, initialize the network with a pretrained network (e.g., one trained on ImageNet). The rest of training proceeds as usual.
import torch.nn as nn
import torch.optim as optim
from torchvision import models

# Load pretrained model (newer torchvision versions prefer weights=models.ResNet18_Weights.DEFAULT)
model_ft = models.resnet18(pretrained=True)

# Replace the final layer for the new task (num_classes: number of classes in the new task)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, num_classes)

# Train all parameters
optimizer = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
2. ConvNet as a Fixed Feature Extractor
Freeze the weights of the entire network except the final fully connected layer, replace that last layer with a new one (randomly initialized), and train only this new layer.
# Load pretrained model
model_conv = models.resnet18(pretrained=True)

# Freeze all parameters
for param in model_conv.parameters():
    param.requires_grad = False

# Replace and train only the final layer
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, num_classes)

# Only optimize the final layer's parameters
optimizer = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
Example: Ants vs Bees Classification
Dataset
A small dataset of ~120 training images per class (ants and bees) and 75 validation images per class. Training from scratch on so little data would normally not generalize, but transfer learning makes reasonable generalization possible.
Data Augmentation
from torchvision import transforms

# Random crops and flips augment the training data; validation uses a deterministic resize and center crop.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
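The training loop in the next section iterates over `dataloaders[phase]`, moves batches to `device`, and normalizes its epoch statistics by `dataset_sizes`. None of these are defined above, so here is a minimal setup sketch, assuming the ants/bees images sit in an ImageFolder-style layout such as `hymenoptera_data/{train,val}/{ants,bees}` (the folder name and batch size are illustrative):

import os
import torch
from torchvision import datasets

data_dir = 'hymenoptera_data'  # assumed layout: hymenoptera_data/{train,val}/{ants,bees}

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes  # e.g., ['ants', 'bees']

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")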
Training Function
import copy
import torch

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        # Each epoch has a training and a validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # training mode
            else:
                model.eval()   # evaluation mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track gradients only during the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            # Epoch statistics (dataset_sizes comes from the data-loading setup above)
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Track best model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    # Load best model weights before returning
    model.load_state_dict(best_model_wts)
    return model
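Putting the pieces together, training the finetuned model from the first scenario might look like the sketch below; the loss function, the StepLR schedule (decay by 0.1 every 7 epochs), and the names criterion, optimizer_ft, and exp_lr_scheduler are illustrative choices rather than requirements:

import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# SGD over all parameters, as in the finetuning scenario above
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay the learning rate by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)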
Results Comparison
| Approach | Training Time | Best Validation Accuracy |
|---|---|---|
| Finetuning all layers | ~37 seconds | 92.8% |
| Fixed feature extractor | ~23 seconds | 95.4% |
The fixed feature extractor approach trains faster, since gradients are not computed for most of the network, and on this small dataset it also reaches higher validation accuracy.
When to Use Each Approach
| Scenario | Recommended Approach |
|---|---|
| Small dataset, similar to pretrained data | Fixed feature extractor |
| Small dataset, different from pretrained data | Finetune with a small learning rate (see the sketch after this table) |
| Large dataset, similar to pretrained data | Finetune all layers |
| Large dataset, different from pretrained data | Train from scratch or finetune heavily |
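To illustrate the "finetune with a small learning rate" recommendation, one common pattern (not covered above, shown here only as a sketch) is to give the pretrained backbone a much smaller learning rate than the freshly initialized final layer via optimizer parameter groups; the specific rates below are illustrative:

import torch.nn as nn
import torch.optim as optim
from torchvision import models

model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)  # num_classes: classes in the new task

# Pretrained backbone parameters get a small learning rate; the new head learns faster.
backbone_params = [p for name, p in model.named_parameters() if not name.startswith('fc.')]
optimizer = optim.SGD([
    {'params': backbone_params, 'lr': 1e-4},        # illustrative small rate for pretrained layers
    {'params': model.fc.parameters(), 'lr': 1e-3},  # larger rate for the new final layer
], momentum=0.9)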