ROI Head: ROI Align + Classification and Box Regression
Notebook 4 of 6 in the Faster RCNN from-scratch series. Given proposals from the RPN, we extract fixed-size features via ROI Align, then classify each proposal and refine its bounding box. Mask RCNN extension point: this notebook also demonstrates the 14×14 ROI Align variant used by the mask head (notebook 07).
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import List, Tuple
Copy
class ROIAlign(nn.Module):
    """ROI Align using bilinear interpolation via F.grid_sample.

    Extracts a fixed (out_size x out_size) feature crop for each proposal,
    selecting the FPN level based on box area (Lin et al. FPN level assignment):

        k = clip(k0 + floor(log2(sqrt(wh) / 224)), k_min, k_max)

    k0=4, k_min=2 (P2), k_max=5 (P5) -> index 0..3 into ``feature_maps``.
    """

    def __init__(self, out_size: int = 7,
                 k0: int = 4, k_min: int = 2, k_max: int = 5):
        super().__init__()
        self.out_size = out_size
        self.k0 = k0
        self.k_min = k_min
        self.k_max = k_max

    def _assign_level(self, boxes: torch.Tensor) -> torch.Tensor:
        """Return 0-indexed FPN level (0=P2, 1=P3, 2=P4, 3=P5) per box.

        Args:
            boxes: (N, 4) tensor of [x1, y1, x2, y2] in image pixels.
        """
        ws = boxes[:, 2] - boxes[:, 0]
        hs = boxes[:, 3] - boxes[:, 1]
        # sqrt(area), clamped so log2 never sees zero for degenerate boxes.
        scales = (ws * hs).clamp(min=1e-6).sqrt()
        levels = torch.floor(self.k0 + torch.log2(scales / 224.0)).long()
        return levels.clamp(self.k_min, self.k_max) - self.k_min

    def forward(self, feature_maps: List[torch.Tensor],
                proposals: List[torch.Tensor],
                image_size: Tuple[int, int]) -> torch.Tensor:
        """
        Args:
            feature_maps: [P2, P3, P4, P5] — (B, C, H_i, W_i) each
            proposals: list of (N_i, 4) per image, pixel coords
            image_size: (H, W) of the input image
        Returns:
            roi_features: (sum(N_i), C, out_size, out_size); an empty
            (0, C, out_size, out_size) tensor when no image has proposals.
        """
        H, W = image_size
        all_features = []
        for batch_idx, props in enumerate(proposals):
            if len(props) == 0:
                continue
            levels = self._assign_level(props)
            # Match the feature dtype so the masked assignment below is valid
            # even for fp16 feature maps.
            feats = torch.zeros(len(props), feature_maps[0].shape[1],
                                self.out_size, self.out_size,
                                device=props.device,
                                dtype=feature_maps[0].dtype)
            for lvl, fm in enumerate(feature_maps):
                mask = levels == lvl
                if not mask.any():
                    continue
                lvl_props = props[mask]
                n = len(lvl_props)
                # Normalise box coords to [-1, 1] for grid_sample.
                # NOTE(review): with align_corners=True, -1/+1 map to pixel
                # *centers* 0 and size-1, so dividing image coords by W/H is
                # slightly misaligned vs. the half-pixel ROI Align convention
                # (align_corners=False would match it exactly) — confirm
                # intent before training.
                x1 = lvl_props[:, 0] / W * 2 - 1
                y1 = lvl_props[:, 1] / H * 2 - 1
                x2 = lvl_props[:, 2] / W * 2 - 1
                y2 = lvl_props[:, 3] / H * 2 - 1
                # out_size x out_size sampling lattice inside each box.
                gx = torch.linspace(0, 1, self.out_size, device=props.device)
                gy = torch.linspace(0, 1, self.out_size, device=props.device)
                gy_g, gx_g = torch.meshgrid(gy, gx, indexing='ij')
                gx_g = x1[:, None, None] + (x2 - x1)[:, None, None] * gx_g[None]
                gy_g = y1[:, None, None] + (y2 - y1)[:, None, None] * gy_g[None]
                grid = torch.stack([gx_g, gy_g], dim=-1)
                # Broadcast this image's feature map across its n proposals.
                fm_exp = fm[batch_idx:batch_idx + 1].expand(n, -1, -1, -1)
                crops = F.grid_sample(fm_exp, grid, align_corners=True,
                                      mode='bilinear', padding_mode='border')
                feats[mask] = crops
            all_features.append(feats)
        if not all_features:
            # No proposals anywhere in the batch: torch.cat([]) would raise,
            # so return a correctly-shaped empty tensor instead.
            channels = feature_maps[0].shape[1]
            return torch.zeros(0, channels, self.out_size, self.out_size,
                               device=feature_maps[0].device,
                               dtype=feature_maps[0].dtype)
        return torch.cat(all_features, dim=0)
Copy
class TwoMLPHead(nn.Module):
    """Box feature extractor: flatten the ROI crop, then two ReLU FC layers.

    Maps (N, C, H, W) ROI features to (N, fc_dim) box features.
    """

    def __init__(self, in_channels: int = 256 * 7 * 7, fc_dim: int = 1024):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, fc_dim)
        self.fc2 = nn.Linear(fc_dim, fc_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N, C, H, W) -> (N, C*H*W)
        flat = torch.flatten(x, start_dim=1)
        hidden = F.relu(self.fc1(flat))
        return F.relu(self.fc2(hidden))
class FastRCNNPredictor(nn.Module):
    """Sibling FC heads: class scores and per-class box deltas.

    Returns (N, num_classes) logits and (N, num_classes * 4) box deltas.
    """

    def __init__(self, in_channels: int = 1024, num_classes: int = 81):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
        # Small-std Gaussian weights, zero biases (standard Fast R-CNN init).
        nn.init.normal_(self.cls_score.weight, std=0.01)
        nn.init.normal_(self.bbox_pred.weight, std=0.001)
        for head in (self.cls_score, self.bbox_pred):
            nn.init.zeros_(head.bias)

    def forward(self, x: torch.Tensor):
        scores = self.cls_score(x)
        deltas = self.bbox_pred(x)
        return scores, deltas
Copy
# Smoke test with dummy feature maps and proposals
roi_align = ROIAlign(out_size=7)
mlp_head = TwoMLPHead()
predictor = FastRCNNPredictor()

# FPN pyramid for an 800x800 image: P2..P5 at strides 4, 8, 16, 32.
feat_maps = [torch.randn(1, 256, side, side) for side in (200, 100, 50, 25)]

proposals = [torch.tensor([
    [ 50.,  50., 300., 300.],
    [100., 100., 400., 400.],
    [200., 200., 600., 600.],
])]

roi_feats = roi_align(feat_maps, proposals, (800, 800))
box_feats = mlp_head(roi_feats)
cls_logits, bbox_preds = predictor(box_feats)

# Expected: [3, 256, 7, 7] / [3, 1024] / [3, 81] / [3, 324]
for label, tensor in (("ROI features", roi_feats),
                      ("Box features", box_feats),
                      ("Class logits", cls_logits),
                      ("Box preds", bbox_preds)):
    print(f"{label}: {tensor.shape}")
Copy
ROI features: torch.Size([3, 256, 7, 7])
Box features: torch.Size([3, 1024])
Class logits: torch.Size([3, 81])
Box preds: torch.Size([3, 324])
Copy
# Inspection: mean-channel activation of 7x7 ROI crops
from pathlib import Path

# savefig does not create directories — make sure 'images/' exists first.
Path('images').mkdir(parents=True, exist_ok=True)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for i, ax in enumerate(axes):
    # Collapse the 256 channels to one heat map per ROI for display.
    crop = roi_feats[i].mean(dim=0).detach().numpy()
    im = ax.imshow(crop, cmap='viridis')
    ax.set_title(f'ROI {i} — 7×7 (mean over 256 ch)')
    ax.axis('off')
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.suptitle('ROI Align 7×7 Crops (dummy feature maps)')
plt.tight_layout()
plt.savefig('images/roi_crops.png', dpi=100, bbox_inches='tight')
plt.show()

Copy
# Mask RCNN extension point: the mask head consumes larger 14x14 crops,
# produced by the same ROIAlign module with a bigger out_size.
mask_roi_align = ROIAlign(out_size=14)
mask_roi_feats = mask_roi_align(feat_maps, proposals, (800, 800))
print(f"Mask ROI features (14x14): {mask_roi_feats.shape}")  # Expected: [3, 256, 14, 14]
print("Extension point ready for Mask RCNN mask head (notebook 07).")
Copy
Mask ROI features (14x14): torch.Size([3, 256, 14, 14])
Extension point ready for Mask RCNN mask head (notebook 07).

