Skip to main content
Open In Colab

ROI Head: ROI Align + Classification and Box Regression

Notebook 4 of 6 in the Faster R-CNN from-scratch series. Given proposals from the RPN, we extract fixed-size features via ROI Align, then classify each proposal and refine its bounding box. Mask R-CNN extension point: this notebook also demonstrates the 14×14 ROI Align variant used by the mask head (notebook 07).
import sys, os, pathlib
# Locate frcnn_common.py — works whether run via papermill or interactively
_nb_candidates = [
    pathlib.Path.cwd().parent,  # interactive: cwd is the notebook dir
    pathlib.Path.cwd() / 'notebooks' / 'scene-understanding' / 'object-detection' / 'faster-rcnn' / 'pytorch',  # papermill: cwd is repo root
]
for _p in _nb_candidates:
    if (_p / 'frcnn_common.py').exists():
        sys.path.insert(0, str(_p))
        break

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import List, Tuple

# Shared building blocks for the series, defined in frcnn_common.py
# (made importable by the sys.path setup in the first cell).
from frcnn_common import (
    ROIAlign, TwoMLPHead, FastRCNNPredictor,
    Bottleneck, ResNet50, FPN,
    AnchorGenerator, RPNHead, RegionProposalNetwork,
    IMG_SIZE, DEVICE,
)

# DEVICE is resolved inside frcnn_common (prints 'cuda' in the recorded run).
print(f"Device: {DEVICE}")
/workspaces/eng-ai-agents/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Device: cuda
# ROIAlign is imported from frcnn_common — inspect its structure:
roi_align = ROIAlign(out_size=7)
print("ROI output size: {}".format(roi_align.out_size))
print("FPN level params: k0={}, k_min={}, k_max={}".format(
    roi_align.k0, roi_align.k_min, roi_align.k_max))
ROI output size: 7
FPN level params: k0=4, k_min=2, k_max=5
# TwoMLPHead and FastRCNNPredictor are imported from frcnn_common:
mlp_head = TwoMLPHead()
predictor = FastRCNNPredictor()
# Print each module's repr, labelled with its class name.
for _label, _module in (("TwoMLPHead", mlp_head), ("FastRCNNPredictor", predictor)):
    print(_label + ":", _module)
TwoMLPHead: TwoMLPHead(
  (fc1): Linear(in_features=12544, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=1024, bias=True)
)
FastRCNNPredictor: FastRCNNPredictor(
  (cls): Linear(in_features=1024, out_features=81, bias=True)
  (box): Linear(in_features=1024, out_features=324, bias=True)
)
# Smoke test with dummy feature maps and proposals
roi_align = ROIAlign(out_size=7)
mlp_head  = TwoMLPHead()
predictor = FastRCNNPredictor()

# One dummy pyramid level per spatial size (100/50/25/13, as in the original run).
feat_maps = [torch.randn(1, 256, side, side) for side in (100, 50, 25, 13)]

# Three xyxy proposals for a single image, in input-pixel coordinates.
proposals = [torch.tensor(
    [[50, 50, 200, 200], [100, 100, 300, 300], [200, 200, 400, 400]],
    dtype=torch.float32,
)]

roi_feats = roi_align(feat_maps, proposals, (800, 800))
box_feats = mlp_head(roi_feats)
cls_logits, bbox_preds = predictor(box_feats)

# Report the shape at each stage of the ROI head pipeline.
for _name, _t in (("ROI features ", roi_feats), ("Box features ", box_feats),
                  ("Class logits ", cls_logits), ("Box deltas   ", bbox_preds)):
    print(f"{_name}: {_t.shape}")
ROI features : torch.Size([3, 256, 7, 7])
Box features : torch.Size([3, 1024])
Class logits : torch.Size([3, 81])
Box deltas   : torch.Size([3, 324])
# Inspection: mean-channel activation of 7x7 ROI crops
# Visualize each ROI's pooled 7x7 feature crop, averaged over its 256 channels.
os.makedirs('images', exist_ok=True)  # savefig raises if the target dir is missing
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for i, ax in enumerate(axes):
    # .cpu() makes the .numpy() conversion safe even if features live on CUDA
    crop = roi_feats[i].mean(dim=0).detach().cpu().numpy()
    im = ax.imshow(crop, cmap='viridis')
    ax.set_title(f'ROI {i} — 7×7 (mean over 256 ch)')
    ax.axis('off')
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.suptitle('ROI Align 7×7 Crops (dummy feature maps)')
plt.tight_layout()
plt.savefig('images/roi_crops.png', dpi=100, bbox_inches='tight')
plt.show()
Output from cell 5
# Mask RCNN extension point: 14x14 ROI Align
mask_roi_align = ROIAlign(out_size=14)
mask_roi_feats = mask_roi_align(feat_maps, proposals, (800, 800))
print("Mask ROI features (14x14): {}".format(mask_roi_feats.shape))
# Expected: [3, 256, 14, 14]
print("Extension point ready for Mask RCNN mask head (notebook 07).")
Mask ROI features (14x14): torch.Size([3, 256, 14, 14])
Extension point ready for Mask RCNN mask head (notebook 07).
Key references: (Redmon et al., 2015; PyTorch-Scratch-Vision-Transformer-ViT, n.d.; Zagoruyko & Komodakis, 2016; Wightman et al., 2021; Ren et al., 2015)

References

  • Redmon, J., Divvala, S., Girshick, R., Farhadi, A. (2015). You only look once: Unified, real-time object detection.
  • Ren, S., He, K., Girshick, R., Sun, J. (2015). Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks.
  • PyTorch-Scratch-Vision-Transformer-ViT (n.d.). Simple and easy-to-understand PyTorch implementation of Vision Transformer (ViT) from scratch, with detailed steps. Tested on common datasets like MNIST, CIFAR10, and more.
  • Wightman, R., Touvron, H., Jégou, H. (2021). ResNet strikes back: An improved training procedure in timm.
  • Zagoruyko, S., Komodakis, N. (2016). Wide Residual Networks.