
Feature Aggregation and Anchor-Free Detection Head

Notebook 3 of 5 in the YOLOv11 From-Scratch Series

In this notebook we build the neck and detection head of YOLOv11. The backbone (Notebook 2) produces multi-scale feature maps P3, P4, and P5, but these features are not yet ready for detection:
  • Deep features (P5) have strong semantics but poor spatial resolution.
  • Shallow features (P3) have fine spatial detail but weak semantics.
The neck bridges this gap through bidirectional feature fusion:
  1. FPN (Feature Pyramid Network) --- top-down pathway that propagates high-level semantic information to lower-level features.
  2. PAN (Path Aggregation Network) --- bottom-up pathway that propagates strong localization signals back up.
  3. C2PSA --- a lightweight CSP-style attention block for feature refinement (implemented here with channel attention).
On top of the fused features, a decoupled anchor-free detection head independently predicts:
  • Classification logits --- probability distribution over object classes.
  • Box regression offsets --- encoded via Distribution Focal Loss (DFL) for precise localization.
By the end of this notebook, you will have a complete YOLOv11 model that supports an end-to-end forward pass.

Imports

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from typing import List, Tuple

Backbone building blocks (from Notebook 2)

The following cells re-define the backbone building blocks introduced in Notebook 2. They are reproduced here in compact form so that this notebook is fully self-contained. Refer to Notebook 2 for detailed explanations of each module.
class ConvBNSiLU(nn.Module):
    def __init__(self, in_c, out_c, k=1, s=1, p=None, g=1):
        super().__init__()
        if p is None:
            p = k // 2
        self.conv = nn.Conv2d(in_c, out_c, k, s, p, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(out_c)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))


class Bottleneck(nn.Module):
    def __init__(self, in_c, out_c, shortcut=True, e=0.5):
        super().__init__()
        mid = int(out_c * e)
        self.cv1 = ConvBNSiLU(in_c, mid, 3)
        self.cv2 = ConvBNSiLU(mid, out_c, 3)
        self.add = shortcut and in_c == out_c

    def forward(self, x):
        y = self.cv2(self.cv1(x))
        return x + y if self.add else y


class C3k2(nn.Module):
    def __init__(self, in_c, out_c, n=2, shortcut=True, e=0.5):
        super().__init__()
        self.c = int(out_c * e)
        self.cv1 = ConvBNSiLU(in_c, 2 * self.c, 1)
        self.cv2 = ConvBNSiLU((2 + n) * self.c, out_c, 1)
        self.bottlenecks = nn.ModuleList(
            Bottleneck(self.c, self.c, shortcut) for _ in range(n)
        )

    def forward(self, x):
        y = list(self.cv1(x).chunk(2, dim=1))
        for bn in self.bottlenecks:
            y.append(bn(y[-1]))
        return self.cv2(torch.cat(y, dim=1))


class SPPF(nn.Module):
    def __init__(self, in_c, out_c, k=5):
        super().__init__()
        mid = in_c // 2
        self.cv1 = ConvBNSiLU(in_c, mid, 1)
        self.cv2 = ConvBNSiLU(mid * 4, out_c, 1)
        self.pool = nn.MaxPool2d(k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.cv1(x)
        y1 = self.pool(x)
        y2 = self.pool(y1)
        y3 = self.pool(y2)
        return self.cv2(torch.cat([x, y1, y2, y3], dim=1))


class YOLOv11Backbone(nn.Module):
    def __init__(self, in_channels=3, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (
            base_channels, base_channels * 2, base_channels * 4,
            base_channels * 8, base_channels * 16,
        )
        self.stem = ConvBNSiLU(in_channels, c1, 3, 2)
        self.stage1 = nn.Sequential(ConvBNSiLU(c1, c2, 3, 2), C3k2(c2, c2, n=2))
        self.stage2 = nn.Sequential(ConvBNSiLU(c2, c3, 3, 2), C3k2(c3, c3, n=2))
        self.stage3 = nn.Sequential(ConvBNSiLU(c3, c4, 3, 2), C3k2(c4, c4, n=2))
        self.stage4 = nn.Sequential(ConvBNSiLU(c4, c5, 3, 2), C3k2(c5, c5, n=2), SPPF(c5, c5))

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        p3 = self.stage2(x)
        p4 = self.stage3(p3)
        p5 = self.stage4(p4)
        return p3, p4, p5
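
As a quick sanity check (not part of the original Notebook 2 code), we can run a dummy 640x640 image through the backbone and confirm the P3/P4/P5 shapes that the rest of this notebook assumes:
# Sanity check: a 640x640 input should yield feature maps at strides 8/16/32.
_backbone = YOLOv11Backbone()
with torch.no_grad():
    _p3, _p4, _p5 = _backbone(torch.randn(1, 3, 640, 640))
print(_p3.shape, _p4.shape, _p5.shape)
# Expected: (1, 256, 80, 80), (1, 512, 40, 40), (1, 1024, 20, 20)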

FPN: Top-Down Path

The Feature Pyramid Network (FPN) implements a top-down pathway that enriches lower-resolution, semantically strong features with higher-resolution spatial information. The process works as follows:
  1. P5 is upsampled (nearest-neighbor interpolation) and concatenated with P4. A C3k2 block fuses the concatenated features.
  2. The fused P4 is upsampled and concatenated with P3. Another C3k2 block produces the final FPN P3 output.
After the FPN, every level in the pyramid carries both high-level semantics from deeper layers and fine-grained spatial detail from shallower layers.
class FPN(nn.Module):
    """Feature Pyramid Network - top-down pathway.

    Fuses high-level semantic features (P5) with lower-level features (P4, P3)
    through upsampling and concatenation.
    """

    def __init__(self, c3: int = 256, c4: int = 512, c5: int = 1024):
        super().__init__()
        # Lateral convolutions to reduce channel dims
        self.lateral_p5 = ConvBNSiLU(c5, c4, 1)
        self.lateral_p4 = ConvBNSiLU(c4, c3, 1)

        # C3k2 blocks after concatenation
        self.fpn_p4 = C3k2(c4 + c4, c4, n=2, shortcut=False)
        self.fpn_p3 = C3k2(c3 + c3, c3, n=2, shortcut=False)

    def forward(self, p3, p4, p5):
        # Top-down: P5 -> P4
        p5_up = F.interpolate(self.lateral_p5(p5), size=p4.shape[2:], mode='nearest')
        fpn_p4 = self.fpn_p4(torch.cat([p5_up, p4], dim=1))

        # Top-down: P4 -> P3
        p4_up = F.interpolate(self.lateral_p4(fpn_p4), size=p3.shape[2:], mode='nearest')
        fpn_p3 = self.fpn_p3(torch.cat([p4_up, p3], dim=1))

        return fpn_p3, fpn_p4, p5
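
A minimal shape check for the FPN on a dummy pyramid (a sketch, assuming a 640x640 input so the three levels are 80/40/20):
# FPN preserves per-level spatial sizes; only the channels are fused.
fpn = FPN(256, 512, 1024)
d3 = torch.randn(1, 256, 80, 80)
d4 = torch.randn(1, 512, 40, 40)
d5 = torch.randn(1, 1024, 20, 20)
with torch.no_grad():
    f3, f4, f5 = fpn(d3, d4, d5)
print(f3.shape, f4.shape, f5.shape)
# Expected: (1, 256, 80, 80), (1, 512, 40, 40), (1, 1024, 20, 20)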

PAN: Bottom-Up Path

The Path Aggregation Network (PAN) complements the FPN with a bottom-up pathway. While FPN carries semantic information downward, PAN carries strong localization features back upward:
  1. FPN P3 is downsampled (stride-2 convolution) and concatenated with FPN P4. A C3k2 block fuses them.
  2. The fused P4 is downsampled and concatenated with P5. Another C3k2 block produces the final PAN P5 output.
The combination of FPN + PAN creates a bidirectional feature fusion pathway. Every scale level now benefits from both high-level category semantics and low-level localization accuracy.
class PAN(nn.Module):
    """Path Aggregation Network - bottom-up pathway.

    Fuses strong localization features from lower levels back up,
    complementing FPN's semantic fusion.
    """

    def __init__(self, c3: int = 256, c4: int = 512, c5: int = 1024):
        super().__init__()
        # Downsample convolutions (k=3, s=2)
        self.down_p3 = ConvBNSiLU(c3, c3, 3, 2)
        self.down_p4 = ConvBNSiLU(c4, c4, 3, 2)

        # C3k2 blocks after concatenation
        self.pan_p4 = C3k2(c3 + c4, c4, n=2, shortcut=False)
        self.pan_p5 = C3k2(c4 + c5, c5, n=2, shortcut=False)

    def forward(self, fpn_p3, fpn_p4, p5):
        # Bottom-up: P3 -> P4
        p3_down = self.down_p3(fpn_p3)
        pan_p4 = self.pan_p4(torch.cat([p3_down, fpn_p4], dim=1))

        # Bottom-up: P4 -> P5
        p4_down = self.down_p4(pan_p4)
        pan_p5 = self.pan_p5(torch.cat([p4_down, p5], dim=1))

        return fpn_p3, pan_p4, pan_p5
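
Continuing the check above, we can chain the PAN onto the FPN outputs and confirm that every level keeps its shape:
# PAN consumes the FPN outputs (f3, f4, f5 from the previous check).
pan = PAN(256, 512, 1024)
with torch.no_grad():
    n3, n4, n5 = pan(f3, f4, f5)
print(n3.shape, n4.shape, n5.shape)
# Expected: (1, 256, 80, 80), (1, 512, 40, 40), (1, 1024, 20, 20)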

C2PSA: Channel Attention Block

YOLOv11 introduces Partial Self-Attention (PSA), a lightweight attention mechanism for feature refinement. This notebook implements a simplified variant that uses squeeze-excitation-style channel attention to emphasize the most informative channels in a feature map while suppressing noise. The C2PSA block follows the CSP (Cross Stage Partial) pattern:
  1. Split the input channels into two halves.
  2. Process one half through bottleneck layers with channel attention (squeeze-excitation style).
  3. Concatenate the unprocessed half with the attended output.
This design keeps computational cost low while still providing the network with a learned mechanism for feature selection. In YOLOv11, C2PSA is applied to the deepest feature map (P5) where the receptive field is largest and attention is most beneficial.
class C2PSA(nn.Module):
    """CSP block with Partial Self-Attention for feature refinement.

    Applies channel attention to selectively emphasize informative features
    while suppressing noise, improving detection at all scales.
    """

    def __init__(self, in_channels: int, out_channels: int, n: int = 1):
        super().__init__()
        self.c = in_channels // 2
        self.cv1 = ConvBNSiLU(in_channels, 2 * self.c, 1)
        self.cv2 = ConvBNSiLU(2 * self.c, out_channels, 1)

        # Attention blocks
        self.attention = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(self.c, self.c // 4),
                nn.SiLU(inplace=True),
                nn.Linear(self.c // 4, self.c),
                nn.Sigmoid()
            ) for _ in range(n)
        ])
        self.bottlenecks = nn.ModuleList([
            Bottleneck(self.c, self.c, shortcut=True) for _ in range(n)
        ])

    def forward(self, x):
        y = list(self.cv1(x).chunk(2, dim=1))
        for attn, bn in zip(self.attention, self.bottlenecks):
            feat = bn(y[-1])
            # Channel attention
            att_weights = attn(feat).unsqueeze(-1).unsqueeze(-1)
            feat = feat * att_weights
            y.append(feat)
        # Use only first split and last bottleneck output
        return self.cv2(torch.cat([y[0], y[-1]], dim=1))
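
A quick check (a sketch, not from the original notebook) that C2PSA is shape-preserving on a P5-sized map:
# C2PSA keeps both channel count and spatial size.
psa = C2PSA(1024, 1024, n=1)
with torch.no_grad():
    refined = psa(torch.randn(1, 1024, 20, 20))
print(refined.shape)  # Expected: torch.Size([1, 1024, 20, 20])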

DFL: Distribution Focal Loss Head

Traditional object detectors regress four continuous values (e.g., center offsets and width/height) for each bounding box. Distribution Focal Loss (DFL) takes a fundamentally different approach: instead of predicting a single scalar per box boundary, the network predicts a discrete probability distribution over a set of reg_max bins.

Why distributions instead of scalars?

  • Ambiguity modeling: Object boundaries are often ambiguous (occlusion, blur). A distribution naturally represents this uncertainty.
  • Better optimization: The softmax-based formulation provides smoother gradients than direct regression.
  • Improved small-object accuracy: The expected-value computation gives sub-bin precision.

How it works

For each of the 4 box boundaries (left, top, right, bottom):
  1. The network outputs reg_max logits.
  2. A softmax converts them to a probability distribution.
  3. The final offset is the expected value: $\hat{o} = \sum_{i=0}^{\text{reg\_max}-1} i \cdot P(i)$.
This gives the model the expressiveness of a full distribution while producing a single precise offset for each boundary.
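
To make the expected-value decode concrete, here is a toy example (using an illustrative reg_max of 8, smaller than the reg_max = 16 used below): a distribution peaked between bins yields a fractional, sub-bin offset.
# Toy DFL decode with reg_max = 8 (illustrative value).
logits = torch.tensor([0., 0., 2., 4., 2., 0., 0., 0.])  # peak around bin 3
probs = F.softmax(logits, dim=0)              # probability distribution over 8 bins
offset = (probs * torch.arange(8.)).sum()     # expected value of the bin index
print(f"{offset.item():.3f}")                 # ~3.05 -- sub-bin precision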

Decoupled Detection Head

class DFLHead(nn.Module):
    """Distribution Focal Loss module for precise box regression.

    Predicts a discrete probability distribution over reg_max bins
    for each of the 4 box boundaries (left, top, right, bottom).
    The expected value gives the final regression offset.
    """

    def __init__(self, reg_max: int = 16):
        super().__init__()
        self.reg_max = reg_max
        # Project: weight vector [0, 1, ..., reg_max-1]
        self.register_buffer('project', torch.arange(reg_max, dtype=torch.float32))

    def forward(self, x):
        """x: (B, 4 * reg_max, H, W) -> (B, 4, H, W) box offsets."""
        b, _, h, w = x.shape
        # Reshape to (B, 4, reg_max, H, W)
        x = x.view(b, 4, self.reg_max, h, w)
        # Softmax over reg_max dimension
        x = F.softmax(x, dim=2)
        # Expected value: weighted sum with [0, 1, ..., reg_max-1]
        x = (x * self.project.view(1, 1, -1, 1, 1)).sum(dim=2)
        return x
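
Before wiring DFLHead into the detection head, a quick property check (a sketch): because the decoded offsets are expectations over bins 0..reg_max-1, they always lie in [0, reg_max - 1].
# DFLHead: (B, 4*reg_max, H, W) logits -> (B, 4, H, W) offsets.
dfl = DFLHead(reg_max=16)
with torch.no_grad():
    offsets = dfl(torch.randn(2, 64, 20, 20))
print(offsets.shape)                                          # torch.Size([2, 4, 20, 20])
print(offsets.min().item() >= 0, offsets.max().item() <= 15)  # True True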


class DetectionHead(nn.Module):
    """Decoupled anchor-free detection head for one scale level.

    Separate classification and regression branches with DFL
    for precise localization.
    """

    def __init__(self, in_channels: int, num_classes: int = 80, reg_max: int = 16):
        super().__init__()
        self.num_classes = num_classes
        self.reg_max = reg_max

        # Classification branch
        self.cls_branch = nn.Sequential(
            ConvBNSiLU(in_channels, in_channels, 3),
            ConvBNSiLU(in_channels, in_channels, 3),
            nn.Conv2d(in_channels, num_classes, 1)
        )

        # Regression branch (DFL)
        self.reg_branch = nn.Sequential(
            ConvBNSiLU(in_channels, in_channels, 3),
            ConvBNSiLU(in_channels, in_channels, 3),
            nn.Conv2d(in_channels, 4 * reg_max, 1)
        )

        self.dfl = DFLHead(reg_max)

    def forward(self, x):
        """
        Args:
            x: (B, C, H, W) feature map from neck
        Returns:
            cls_pred: (B, num_classes, H, W) class logits
            box_pred: (B, 4, H, W) decoded box offsets (ltrb)
            box_raw: (B, 4*reg_max, H, W) raw DFL logits (for loss)
        """
        cls_pred = self.cls_branch(x)
        box_raw = self.reg_branch(x)
        box_pred = self.dfl(box_raw)
        return cls_pred, box_pred, box_raw
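
And a single-head shape check on a P3-sized feature map (stride 8 on a 640x640 input):
# One detection head at the P3 level.
head = DetectionHead(in_channels=256, num_classes=80, reg_max=16)
with torch.no_grad():
    cls_p, box_p, box_r = head(torch.randn(1, 256, 80, 80))
print(cls_p.shape, box_p.shape, box_r.shape)
# Expected: (1, 80, 80, 80), (1, 4, 80, 80), (1, 64, 80, 80)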

Full YOLOv11 Model Assembly

We now assemble the complete YOLOv11 model by combining backbone, FPN + PAN neck, C2PSA attention, and decoupled detection heads into a single end-to-end module.
class YOLOv11(nn.Module):
    """Complete YOLOv11 object detection model.

    Assembles backbone, FPN+PAN neck, and decoupled detection heads
    into a single end-to-end model.
    """

    def __init__(self, num_classes: int = 80, reg_max: int = 16, base_channels: int = 64):
        super().__init__()
        c3, c4, c5 = base_channels * 4, base_channels * 8, base_channels * 16

        self.backbone = YOLOv11Backbone(base_channels=base_channels)
        self.fpn = FPN(c3, c4, c5)
        self.pan = PAN(c3, c4, c5)

        # C2PSA attention on the deepest level (P5), applied after the PAN
        self.c2psa = C2PSA(c5, c5, n=1)

        # Detection heads - one per scale
        self.head_p3 = DetectionHead(c3, num_classes, reg_max)
        self.head_p4 = DetectionHead(c4, num_classes, reg_max)
        self.head_p5 = DetectionHead(c5, num_classes, reg_max)

        self.strides = [8, 16, 32]
        self.num_classes = num_classes
        self.reg_max = reg_max

    def forward(self, x):
        """
        Args:
            x: (B, 3, 640, 640) input images
        Returns:
            predictions: list of (cls_pred, box_pred, box_raw) per scale
        """
        # Backbone
        p3, p4, p5 = self.backbone(x)

        # Neck: FPN + PAN
        fpn_p3, fpn_p4, fpn_p5 = self.fpn(p3, p4, p5)
        pan_p3, pan_p4, pan_p5 = self.pan(fpn_p3, fpn_p4, fpn_p5)

        # Attention on P5
        pan_p5 = self.c2psa(pan_p5)

        # Detection heads
        pred_p3 = self.head_p3(pan_p3)
        pred_p4 = self.head_p4(pan_p4)
        pred_p5 = self.head_p5(pan_p5)

        return [pred_p3, pred_p4, pred_p5]

Shape Verification

Let us verify that the full model produces outputs of the expected shapes at each scale level.
model = YOLOv11(num_classes=80)
x = torch.randn(2, 3, 640, 640)

with torch.no_grad():
    predictions = model(x)

print("Input shape:", x.shape)
print()
for i, (cls_pred, box_pred, box_raw) in enumerate(predictions):
    stride = model.strides[i]
    print(f"P{i+3} (stride {stride}):")
    print(f"  cls_pred: {cls_pred.shape}  (B, {model.num_classes} classes, H, W)")
    print(f"  box_pred: {box_pred.shape}  (B, 4 ltrb offsets, H, W)")
    print(f"  box_raw:  {box_raw.shape}  (B, 4x{model.reg_max} DFL bins, H, W)")
Input shape: torch.Size([2, 3, 640, 640])

P3 (stride 8):
  cls_pred: torch.Size([2, 80, 80, 80])  (B, 80 classes, H, W)
  box_pred: torch.Size([2, 4, 80, 80])  (B, 4 ltrb offsets, H, W)
  box_raw:  torch.Size([2, 64, 80, 80])  (B, 4x16 DFL bins, H, W)
P4 (stride 16):
  cls_pred: torch.Size([2, 80, 40, 40])  (B, 80 classes, H, W)
  box_pred: torch.Size([2, 4, 40, 40])  (B, 4 ltrb offsets, H, W)
  box_raw:  torch.Size([2, 64, 40, 40])  (B, 4x16 DFL bins, H, W)
P5 (stride 32):
  cls_pred: torch.Size([2, 80, 20, 20])  (B, 80 classes, H, W)
  box_pred: torch.Size([2, 4, 20, 20])  (B, 4 ltrb offsets, H, W)
  box_raw:  torch.Size([2, 64, 20, 20])  (B, 4x16 DFL bins, H, W)

Parameter Count

def count_params(model, name="Model"):
    total = sum(p.numel() for p in model.parameters())
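    # Size estimate assumes fp32 (4 bytes per parameter).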
    print(f"\n{name}: {total:,} parameters ({total * 4 / 1024**2:.1f} MB)")
    return total

print("=== YOLOv11 Parameter Breakdown ===")
for name, module in model.named_children():
    count_params(module, name)
count_params(model, "Total YOLOv11")
=== YOLOv11 Parameter Breakdown ===

backbone: 19,355,328 parameters (73.8 MB)

fpn: 3,447,552 parameters (13.2 MB)

pan: 13,447,168 parameters (51.3 MB)

c2psa: 4,593,792 parameters (17.5 MB)

head_p3: 2,398,352 parameters (9.1 MB)

head_p4: 9,515,152 parameters (36.3 MB)

head_p5: 37,904,528 parameters (144.6 MB)

Total YOLOv11: 90,661,872 parameters (345.8 MB)

Architecture Visualization

The following diagram illustrates the complete information flow through the YOLOv11 architecture: backbone feature extraction, FPN top-down fusion, PAN bottom-up fusion, and the per-scale detection heads.
def visualize_architecture():
    fig, ax = plt.subplots(figsize=(14, 10))
    ax.set_xlim(0, 14)
    ax.set_ylim(0, 10)
    ax.axis('off')

    # Draw backbone
    backbone_boxes = [
        (1, 8.5, 'Stem\n3->64', 'lightblue'),
        (1, 7.0, 'Stage1\n64->128', 'lightblue'),
        (1, 5.5, 'Stage2\n128->256', 'lightblue'),
        (1, 4.0, 'Stage3\n256->512', 'lightblue'),
        (1, 2.5, 'Stage4+SPPF\n512->1024', 'lightblue'),
    ]

    for bx, by, text, color in backbone_boxes:
        rect = plt.Rectangle((bx - 0.7, by - 0.5), 1.4, 0.8,
                             facecolor=color, edgecolor='black', linewidth=1.5)
        ax.add_patch(rect)
        ax.text(bx, by, text, ha='center', va='center', fontsize=8, fontweight='bold')

    # FPN arrows and boxes
    fpn_x = 5
    ax.text(fpn_x, 9.2, 'FPN (Top-Down)', ha='center', fontsize=11,
            fontweight='bold', color='green')

    fpn_boxes = [
        (fpn_x, 5.5, 'FPN P3\n256', 'lightgreen'),
        (fpn_x, 4.0, 'FPN P4\n512', 'lightgreen'),
        (fpn_x, 2.5, 'P5\n1024', 'lightgreen'),
    ]
    for bx, by, text, color in fpn_boxes:
        rect = plt.Rectangle((bx - 0.7, by - 0.5), 1.4, 0.8,
                             facecolor=color, edgecolor='black', linewidth=1.5)
        ax.add_patch(rect)
        ax.text(bx, by, text, ha='center', va='center', fontsize=8, fontweight='bold')

    # PAN boxes
    pan_x = 9
    ax.text(pan_x, 9.2, 'PAN (Bottom-Up)', ha='center', fontsize=11,
            fontweight='bold', color='orange')

    pan_boxes = [
        (pan_x, 5.5, 'PAN P3\n256', 'moccasin'),
        (pan_x, 4.0, 'PAN P4\n512', 'moccasin'),
        (pan_x, 2.5, 'PAN P5\n1024', 'moccasin'),
    ]
    for bx, by, text, color in pan_boxes:
        rect = plt.Rectangle((bx - 0.7, by - 0.5), 1.4, 0.8,
                             facecolor=color, edgecolor='black', linewidth=1.5)
        ax.add_patch(rect)
        ax.text(bx, by, text, ha='center', va='center', fontsize=8, fontweight='bold')

    # Heads
    head_x = 12.5
    ax.text(head_x, 9.2, 'Detection Heads', ha='center', fontsize=11,
            fontweight='bold', color='red')

    head_boxes = [
        (head_x, 5.5, 'Head P3\ncls+reg', 'lightyellow'),
        (head_x, 4.0, 'Head P4\ncls+reg', 'lightyellow'),
        (head_x, 2.5, 'Head P5\ncls+reg', 'lightyellow'),
    ]
    for bx, by, text, color in head_boxes:
        rect = plt.Rectangle((bx - 0.8, by - 0.5), 1.6, 0.8,
                             facecolor=color, edgecolor='black', linewidth=1.5)
        ax.add_patch(rect)
        ax.text(bx, by, text, ha='center', va='center', fontsize=8, fontweight='bold')

    ax.set_title('YOLOv11 Architecture: Backbone -> FPN -> PAN -> Heads',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

visualize_architecture()
[Figure: YOLOv11 Architecture: Backbone -> FPN -> PAN -> Heads]

Summary

In this notebook we built the complete YOLOv11 architecture on top of the backbone from Notebook 2:
  • FPN (top-down) propagates high-level semantics from P5 down to P3, ensuring that every scale level understands what objects are present.
  • PAN (bottom-up) propagates strong localization features from P3 back up to P5, ensuring that every scale level knows where objects are.
  • C2PSA applies lightweight channel attention to P5, allowing the network to selectively emphasize the most informative features.
  • DFL (Distribution Focal Loss) replaces direct box regression with a discrete distribution over offsets, enabling more precise localization---especially for small objects.
  • Decoupled detection heads separate classification and regression into independent branches, allowing each task to specialize without interfering with the other.
The model produces predictions at three scales (P3/P4/P5 with strides 8/16/32), covering objects from small to large. Next: Notebook 4 covers the loss functions used to train this model, including the task-aligned assigner, classification loss (BCE), box regression loss (CIoU + DFL), and the overall multi-task training objective.