
Loss Functions, Label Assignment, and Training

Notebook 4 of 5 in the YOLOv11 from-scratch series

Introduction

With our model architecture complete (backbone, neck, and head from Notebooks 2-3), we now face three critical challenges that determine whether the detector actually learns to find objects:
  1. IoU computation - How do we measure the geometric overlap between predicted and ground-truth boxes? The choice of IoU variant directly affects gradient quality and convergence speed.
  2. Label assignment strategy - Given thousands of anchor points but only a handful of ground-truth boxes per image, which anchors should be responsible for predicting each object? This is the assignment problem.
  3. Loss function design - How do we combine classification, localization, and distribution regression objectives into a single scalar loss that balances all three tasks?
YOLOv11 addresses these with:
  • CIoU (Complete IoU) for box regression, which captures overlap, center distance, and aspect ratio in a single differentiable metric
  • Task-Aligned Learning (TAL) for label assignment, which selects anchors based on both classification confidence and localization quality
  • A composite loss combining BCE classification, CIoU box regression, and Distribution Focal Loss (DFL)
We will build each component from scratch and run a small training loop on a handful of real COCO images to verify that the entire pipeline works end-to-end.

Imports

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math
from typing import List, Tuple, Optional

from datasets import load_dataset
from PIL import Image

Model components from Notebooks 2-3

To keep this notebook self-contained, we re-define all model components (backbone, neck, head) in a single compact cell. These are identical to the implementations in Notebooks 2 and 3. Refer to those notebooks for detailed explanations of each block.
class ConvBNSiLU(nn.Module):
    def __init__(self, c_in, c_out, k=1, s=1, p=None, g=1):
        super().__init__()
        if p is None: p = k // 2
        self.conv = nn.Conv2d(c_in, c_out, k, s, p, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.act = nn.SiLU(inplace=True)
    def forward(self, x): return self.act(self.bn(self.conv(x)))

class Bottleneck(nn.Module):
    def __init__(self, c_in, c_out, shortcut=True, k=(3,3), e=0.5):
        super().__init__()
        c_hid = int(c_out * e)
        self.cv1 = ConvBNSiLU(c_in, c_hid, k[0])
        self.cv2 = ConvBNSiLU(c_hid, c_out, k[1])
        self.add = shortcut and c_in == c_out
    def forward(self, x): return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

class C3k2(nn.Module):
    def __init__(self, c_in, c_out, n=1, shortcut=True, e=0.5):
        super().__init__()
        self.c = int(c_out * e)
        self.cv1 = ConvBNSiLU(c_in, 2 * self.c, 1)
        self.cv2 = ConvBNSiLU((2 + n) * self.c, c_out, 1)
        self.bottlenecks = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, (3,3), 1.0) for _ in range(n))
    def forward(self, x):
        y = list(self.cv1(x).chunk(2, dim=1))
        for bn in self.bottlenecks: y.append(bn(y[-1]))
        return self.cv2(torch.cat(y, dim=1))

class SPPF(nn.Module):
    def __init__(self, c_in, c_out, k=5):
        super().__init__()
        c_hid = c_in // 2
        self.cv1 = ConvBNSiLU(c_in, c_hid, 1)
        self.cv2 = ConvBNSiLU(c_hid * 4, c_out, 1)
        self.pool = nn.MaxPool2d(k, stride=1, padding=k // 2)
    def forward(self, x):
        x = self.cv1(x); y1 = self.pool(x); y2 = self.pool(y1); y3 = self.pool(y2)
        return self.cv2(torch.cat([x, y1, y2, y3], dim=1))

class YOLOv11Backbone(nn.Module):
    def __init__(self, c_in=3, base=64):
        super().__init__()
        c1, c2, c3, c4, c5 = base, base*2, base*4, base*8, base*16
        self.stem = ConvBNSiLU(c_in, c1, 3, s=2)
        self.s1_down = ConvBNSiLU(c1, c2, 3, s=2); self.s1_c3k2 = C3k2(c2, c2, n=2)
        self.s2_down = ConvBNSiLU(c2, c3, 3, s=2); self.s2_c3k2 = C3k2(c3, c3, n=2)
        self.s3_down = ConvBNSiLU(c3, c4, 3, s=2); self.s3_c3k2 = C3k2(c4, c4, n=2)
        self.s4_down = ConvBNSiLU(c4, c5, 3, s=2); self.s4_c3k2 = C3k2(c5, c5, n=2)
        self.sppf = SPPF(c5, c5)
    def forward(self, x):
        x = self.stem(x)
        x = self.s1_c3k2(self.s1_down(x))
        p3 = self.s2_c3k2(self.s2_down(x))
        p4 = self.s3_c3k2(self.s3_down(p3))
        p5 = self.sppf(self.s4_c3k2(self.s4_down(p4)))
        return p3, p4, p5

class FPN(nn.Module):
    def __init__(self, c3=256, c4=512, c5=1024):
        super().__init__()
        self.lateral5 = ConvBNSiLU(c5, c4, 1)
        self.lateral4 = ConvBNSiLU(c4, c3, 1)
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.fuse4 = C3k2(c4 * 2, c4, n=2)
        self.fuse3 = C3k2(c3 * 2, c3, n=2)
    def forward(self, p3, p4, p5):
        p5_up = self.up(self.lateral5(p5))
        p4 = self.fuse4(torch.cat([p4, p5_up], dim=1))
        p4_up = self.up(self.lateral4(p4))
        p3 = self.fuse3(torch.cat([p3, p4_up], dim=1))
        return p3, p4, p5

class PAN(nn.Module):
    def __init__(self, c3=256, c4=512, c5=1024):
        super().__init__()
        self.down3 = ConvBNSiLU(c3, c3, 3, s=2)
        self.fuse4 = C3k2(c3 + c4, c4, n=2)
        self.down4 = ConvBNSiLU(c4, c4, 3, s=2)
        self.fuse5 = C3k2(c4 + c5, c5, n=2)
    def forward(self, p3, p4, p5):
        p4 = self.fuse4(torch.cat([self.down3(p3), p4], dim=1))
        p5 = self.fuse5(torch.cat([self.down4(p4), p5], dim=1))
        return p3, p4, p5

class C2PSA(nn.Module):
    def __init__(self, c, n_heads=8):
        super().__init__()
        self.cv1 = ConvBNSiLU(c, c, 1)
        self.attn = nn.MultiheadAttention(c, n_heads, batch_first=True)
        self.ffn = nn.Sequential(ConvBNSiLU(c, c * 2, 1), ConvBNSiLU(c * 2, c, 1))
        self.cv2 = ConvBNSiLU(c, c, 1)
    def forward(self, x):
        y = self.cv1(x)
        B, C, H, W = y.shape
        flat = y.flatten(2).permute(0, 2, 1)
        flat = flat + self.attn(flat, flat, flat, need_weights=False)[0]
        y = flat.permute(0, 2, 1).view(B, C, H, W)
        y = y + self.ffn(y)
        return self.cv2(y)

class DFLHead(nn.Module):
    def __init__(self, c_in, num_classes=80, reg_max=16):
        super().__init__()
        self.reg_max = reg_max
        self.num_classes = num_classes
        self.cls_convs = nn.Sequential(ConvBNSiLU(c_in, c_in, 3), ConvBNSiLU(c_in, c_in, 3))
        self.reg_convs = nn.Sequential(ConvBNSiLU(c_in, c_in, 3), ConvBNSiLU(c_in, c_in, 3))
        self.cls_pred = nn.Conv2d(c_in, num_classes, 1)
        self.reg_pred = nn.Conv2d(c_in, 4 * reg_max, 1)
        self.proj = nn.Parameter(torch.arange(reg_max, dtype=torch.float32), requires_grad=False)
    def forward(self, x):
        cls_out = self.cls_pred(self.cls_convs(x))
        reg_raw = self.reg_pred(self.reg_convs(x))
        B, _, H, W = reg_raw.shape
        reg_dist = reg_raw.view(B, 4, self.reg_max, H, W)
        reg_box = F.softmax(reg_dist, dim=2)
        reg_box = (reg_box * self.proj.view(1, 1, -1, 1, 1)).sum(dim=2)
        return cls_out, reg_box, reg_raw

class DetectionHead(nn.Module):
    def __init__(self, channels=[256, 512, 1024], num_classes=80, reg_max=16):
        super().__init__()
        self.heads = nn.ModuleList([DFLHead(c, num_classes, reg_max) for c in channels])
    def forward(self, features):
        return [head(f) for head, f in zip(self.heads, features)]

class YOLOv11(nn.Module):
    def __init__(self, num_classes=80, reg_max=16):
        super().__init__()
        self.backbone = YOLOv11Backbone()
        self.fpn = FPN()
        self.pan = PAN()
        self.c2psa = C2PSA(1024)
        self.head = DetectionHead(num_classes=num_classes, reg_max=reg_max)
    def forward(self, x):
        p3, p4, p5 = self.backbone(x)
        p5 = self.c2psa(p5)
        p3, p4, p5 = self.fpn(p3, p4, p5)
        p3, p4, p5 = self.pan(p3, p4, p5)
        return self.head([p3, p4, p5])

print("Model components loaded successfully.")
print(f"YOLOv11 parameters: {sum(p.numel() for p in YOLOv11(num_classes=80).parameters()):,}")
Model components loaded successfully.
YOLOv11 parameters: 110,212,128

The evolution of IoU metrics

Intersection over Union (IoU) is the foundational metric for measuring bounding box quality. However, the basic IoU has significant limitations that led to a series of improvements:

IoU (Intersection over Union)

$$\text{IoU} = \frac{|B_p \cap B_{gt}|}{|B_p \cup B_{gt}|}$$

Simple and intuitive, but it has a critical flaw: when two boxes do not overlap, IoU is zero regardless of how far apart they are. This means zero gradient for non-overlapping predictions, making it useless as a standalone loss for poorly initialized detectors.

GIoU (Generalized IoU)

$$\text{GIoU} = \text{IoU} - \frac{|C \setminus (B_p \cup B_{gt})|}{|C|}$$

where $C$ is the smallest enclosing box. GIoU adds a penalty for the gap between the predicted and ground-truth boxes. It provides gradients even when boxes do not overlap, but converges slowly because it only penalizes the empty area ratio, not the distance directly.

DIoU (Distance IoU)

$$\text{DIoU} = \text{IoU} - \frac{d^2(\mathbf{b}_p, \mathbf{b}_{gt})}{c^2}$$

where $d$ is the Euclidean distance between box centers and $c$ is the diagonal of the enclosing box. By directly penalizing center-point distance, DIoU converges much faster than GIoU.

CIoU (Complete IoU)

$$\text{CIoU} = \text{IoU} - \frac{d^2(\mathbf{b}_p, \mathbf{b}_{gt})}{c^2} - \alpha v$$

where $v = \frac{4}{\pi^2}\left(\arctan\frac{w_{gt}}{h_{gt}} - \arctan\frac{w_p}{h_p}\right)^2$ measures aspect ratio consistency, and $\alpha = \frac{v}{(1 - \text{IoU}) + v}$ is an adaptive weight. CIoU provides complete geometric alignment by considering overlap, center distance, and aspect ratio simultaneously. This is what YOLOv11 uses for box regression.

IoU implementations

def compute_iou(box1, box2, mode='ciou', eps=1e-7):
    """Compute IoU variants between two sets of boxes.
    
    Args:
        box1: (N, 4) in [x1, y1, x2, y2] format
        box2: (M, 4) in [x1, y1, x2, y2] format
        mode: 'iou', 'giou', 'diou', or 'ciou'
    Returns:
        iou: (N, M) pairwise IoU values
    """
    # Intersection
    inter_x1 = torch.max(box1[:, None, 0], box2[None, :, 0])
    inter_y1 = torch.max(box1[:, None, 1], box2[None, :, 1])
    inter_x2 = torch.min(box1[:, None, 2], box2[None, :, 2])
    inter_y2 = torch.min(box1[:, None, 3], box2[None, :, 3])
    inter = (inter_x2 - inter_x1).clamp(0) * (inter_y2 - inter_y1).clamp(0)
    
    # Union
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    union = area1[:, None] + area2[None, :] - inter
    
    iou = inter / (union + eps)
    
    if mode == 'iou':
        return iou
    
    # Enclosing box
    enc_x1 = torch.min(box1[:, None, 0], box2[None, :, 0])
    enc_y1 = torch.min(box1[:, None, 1], box2[None, :, 1])
    enc_x2 = torch.max(box1[:, None, 2], box2[None, :, 2])
    enc_y2 = torch.max(box1[:, None, 3], box2[None, :, 3])
    enc_area = (enc_x2 - enc_x1) * (enc_y2 - enc_y1)
    
    if mode == 'giou':
        return iou - (enc_area - union) / (enc_area + eps)
    
    # Center distance
    cx1 = (box1[:, 0] + box1[:, 2]) / 2
    cy1 = (box1[:, 1] + box1[:, 3]) / 2
    cx2 = (box2[:, 0] + box2[:, 2]) / 2
    cy2 = (box2[:, 1] + box2[:, 3]) / 2
    
    center_dist = (cx1[:, None] - cx2[None, :]) ** 2 + (cy1[:, None] - cy2[None, :]) ** 2
    diag_dist = (enc_x2 - enc_x1) ** 2 + (enc_y2 - enc_y1) ** 2
    
    if mode == 'diou':
        return iou - center_dist / (diag_dist + eps)
    
    # CIoU: aspect ratio penalty
    w1 = box1[:, 2] - box1[:, 0]
    h1 = box1[:, 3] - box1[:, 1]
    w2 = box2[:, 2] - box2[:, 0]
    h2 = box2[:, 3] - box2[:, 1]
    
    v = (4 / math.pi ** 2) * (
        torch.atan(w2[None, :] / (h2[None, :] + eps)) - 
        torch.atan(w1[:, None] / (h1[:, None] + eps))
    ) ** 2
    
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)
    
    return iou - center_dist / (diag_dist + eps) - alpha * v
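
As a quick sanity check of the behavior described above, we can compare the variants on two hand-picked box pairs: every variant scores identical boxes near 1, while only plain IoU collapses to exactly zero for disjoint boxes (GIoU, DIoU, and CIoU go negative and keep carrying distance information).
ref = torch.tensor([[0.0, 0.0, 4.0, 4.0]])
same = ref.clone()
far = torch.tensor([[10.0, 0.0, 14.0, 4.0]])   # same size, shifted fully outside ref

for mode in ['iou', 'giou', 'diou', 'ciou']:
    same_val = compute_iou(ref, same, mode=mode).item()
    far_val = compute_iou(ref, far, mode=mode).item()
    print(f"{mode:>4}: identical={same_val:+.3f}  disjoint={far_val:+.3f}")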

Visualizing IoU variants

To build intuition for how these metrics differ, we slide a prediction box horizontally away from a fixed ground-truth box and plot each IoU variant. Notice how:
  • IoU drops to zero once boxes separate and stays there (no gradient signal)
  • GIoU continues to decrease below zero but slowly
  • DIoU decreases more steeply due to the direct distance penalty
  • CIoU behaves like DIoU here (same aspect ratio), but would differ for shape changes
def visualize_iou_variants():
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    
    # Fixed reference box
    ref = torch.tensor([[2.0, 2.0, 5.0, 5.0]])
    
    # Move a box horizontally
    offsets = torch.linspace(0, 6, 50)
    modes = ['iou', 'giou', 'diou', 'ciou']
    
    for ax, mode in zip(axes, modes):
        values = []
        for dx in offsets:
            pred = torch.tensor([[2.0 + dx.item(), 2.0, 5.0 + dx.item(), 5.0]])
            val = compute_iou(ref, pred, mode=mode)
            values.append(val.item())
        ax.plot(offsets.numpy(), values, linewidth=2)
        ax.set_title(mode.upper(), fontsize=14, fontweight='bold')
        ax.set_xlabel('Horizontal offset')
        ax.set_ylabel(f'{mode} value')
        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
        ax.grid(True, alpha=0.3)
    
    plt.suptitle('IoU Variants: Response to Horizontal Box Translation', fontsize=14)
    plt.tight_layout()
    plt.show()

visualize_iou_variants()
Output from cell 4

Task-Aligned Learning (TAL)

Label assignment is the bridge between ground-truth annotations and the thousands of predictions a detector makes. For each ground-truth box, we need to decide which anchor points are “responsible” for predicting it.

The problem with simpler strategies

  • IoU-based assignment (used in earlier YOLO versions): assigns anchors based purely on spatial overlap with the GT. This ignores whether the model is actually confident about the prediction, leading to a mismatch between the anchors that get assigned and the predictions the model can actually make well.
  • Center-based assignment (e.g., FCOS): assigns all anchors whose centers fall inside the GT box. Simple but does not consider prediction quality.

Task-Aligned Learning

TAL resolves this by computing an alignment metric that combines both classification and localization quality:

$$t = s^\alpha \cdot u^\beta$$

where:
  • $s$ is the predicted classification score for the GT class
  • $u$ is the IoU between the predicted box and the GT box
  • $\alpha = 1.0$ and $\beta = 6.0$ control the relative importance (localization is weighted much more heavily)
The assignment procedure:
  1. Filter to anchors whose centers lie inside each GT box
  2. Compute the alignment metric $t$ for all valid anchor-GT pairs
  3. Select the top-k anchors (default $k=13$) per GT based on $t$
  4. Resolve conflicts (anchor assigned to multiple GTs) by keeping the highest-alignment GT
  5. Generate soft label targets by normalizing the alignment scores
This approach is more effective because it assigns labels to anchors that the model is already doing well on, creating a positive feedback loop that accelerates training.
class TaskAlignedAssigner:
    """Task-Aligned Label Assignment for anchor-free detection.
    
    Assigns ground truth to predictions using alignment metric
    that considers both classification confidence and box IoU.
    """
    
    def __init__(self, topk: int = 13, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
        self.topk = topk
        self.alpha = alpha
        self.beta = beta
        self.eps = eps
    
    @torch.no_grad()
    def assign(self, pred_scores, pred_bboxes, gt_labels, gt_bboxes, anchor_points, stride):
        """
        Args:
            pred_scores: (num_anchors, num_classes) predicted class scores (sigmoid)
            pred_bboxes: (num_anchors, 4) predicted boxes [x1, y1, x2, y2]
            gt_labels: (num_gt,) ground truth class indices
            gt_bboxes: (num_gt, 4) ground truth boxes [x1, y1, x2, y2]
            anchor_points: (num_anchors, 2) anchor center positions
            stride: feature stride for this level
        Returns:
            assigned_labels: (num_anchors,) -1 for background
            assigned_bboxes: (num_anchors, 4) 
            assigned_scores: (num_anchors, num_classes) soft labels
        """
        device = pred_scores.device
        num_anchors = pred_scores.shape[0]
        num_gt = gt_bboxes.shape[0]
        num_classes = pred_scores.shape[1]
        
        if num_gt == 0:
            return (
                torch.full((num_anchors,), -1, dtype=torch.long, device=device),
                torch.zeros((num_anchors, 4), device=device),
                torch.zeros((num_anchors, num_classes), device=device)
            )
        
        # Check if anchor centers fall inside GT boxes
        # anchor_points: (num_anchors, 2) [cx, cy]
        lt = anchor_points[:, None, :] - gt_bboxes[None, :, :2]  # (na, ng, 2)
        rb = gt_bboxes[None, :, 2:] - anchor_points[:, None, :]  # (na, ng, 2)
        in_gt = torch.cat([lt, rb], dim=-1).min(dim=-1).values > 0  # (na, ng)
        
        # Compute alignment metric
        # Get predicted class scores for GT classes
        gt_cls_scores = pred_scores[:, gt_labels]  # (na, ng)
        
        # Compute pairwise IoU
        pair_iou = compute_iou(pred_bboxes, gt_bboxes, mode='iou')  # (na, ng)
        pair_iou = pair_iou.clamp(0, 1)
        
        # Alignment metric: score^alpha * iou^beta
        alignment = gt_cls_scores.pow(self.alpha) * pair_iou.pow(self.beta)
        alignment[~in_gt] = 0  # mask out anchors not inside GT
        
        # Select top-k anchors per GT
        topk_mask = torch.zeros_like(alignment, dtype=torch.bool)
        for j in range(num_gt):
            vals = alignment[:, j]
            k = min(self.topk, (vals > 0).sum().item())
            if k > 0:
                _, topk_idx = vals.topk(k)
                topk_mask[topk_idx, j] = True
        
        alignment[~topk_mask] = 0
        
        # Resolve conflicts: each anchor -> highest alignment GT
        assigned_gt = alignment.argmax(dim=1)  # (na,)
        max_alignment = alignment.max(dim=1).values  # (na,)
        
        # Background mask
        bg_mask = max_alignment < self.eps
        assigned_gt[bg_mask] = -1
        
        # Build outputs
        assigned_labels = torch.where(
            bg_mask,
            torch.tensor(-1, device=device),
            gt_labels[assigned_gt.clamp(min=0)]
        )
        assigned_labels[bg_mask] = -1
        assigned_bboxes = torch.zeros((num_anchors, 4), device=device)
        assigned_bboxes[~bg_mask] = gt_bboxes[assigned_gt[~bg_mask]]
        
        # Soft label targets (normalized alignment score)
        assigned_scores = torch.zeros((num_anchors, num_classes), device=device)
        fg_mask = ~bg_mask
        if fg_mask.any():
            norm_align = max_alignment[fg_mask] / (max_alignment[fg_mask].max() + self.eps)
            assigned_scores[fg_mask, assigned_labels[fg_mask]] = norm_align
        
        return assigned_labels, assigned_bboxes, assigned_scores

print("TaskAlignedAssigner ready.")
TaskAlignedAssigner ready.
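
To see the assigner in action before wiring it into the loss, here is a toy example with made-up numbers: four anchors, one ground-truth box of class 0, and two classes. Anchors inside the GT box with the highest score-times-IoU alignment are kept (topk=2 here); the rest fall back to background (-1).
assigner_demo = TaskAlignedAssigner(topk=2)

anchor_pts = torch.tensor([[10., 10.], [30., 30.], [50., 50.], [90., 90.]])
pred_scores = torch.tensor([[0.9, 0.1],   # anchor 0: confident in class 0
                            [0.6, 0.2],
                            [0.2, 0.1],
                            [0.1, 0.1]])  # anchor 3: low confidence, outside the GT anyway
pred_boxes = torch.tensor([[ 5.,  5., 45., 45.],
                           [20., 20., 60., 60.],
                           [40., 40., 70., 70.],
                           [85., 85., 95., 95.]])
gt_boxes_demo = torch.tensor([[0., 0., 60., 60.]])   # one GT box of class 0
gt_labels_demo = torch.tensor([0])

labels, boxes, scores = assigner_demo.assign(
    pred_scores, pred_boxes, gt_labels_demo, gt_boxes_demo, anchor_pts, stride=8)
print("assigned labels:", labels.tolist())                    # -1 marks background anchors
print("soft targets   :", scores.max(dim=1).values.tolist())  # normalized alignment per anchor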

Composite loss function

YOLOv11’s loss function combines three complementary objectives:
  1. Classification loss (BCE with soft labels): Binary cross-entropy between predicted class logits and the soft label targets produced by TAL. Soft labels (values between 0 and 1 based on alignment quality) provide richer supervisory signal than hard 0/1 labels.
  2. Box regression loss (CIoU): $\mathcal{L}_{box} = 1 - \text{CIoU}(\hat{b}, b^*)$, applied only to foreground (assigned) anchors. CIoU captures overlap, center distance, and aspect ratio in a single loss term.
  3. Distribution Focal Loss (DFL): Instead of directly regressing LTRB offsets, the DFL head predicts a discrete probability distribution over integer bins $\{0, 1, \ldots, \text{reg\_max}-1\}$. The DFL loss is a weighted cross-entropy between adjacent bins:

$$\mathcal{L}_{DFL}(S_i, S_{i+1}, y) = -(1 - (y - i)) \log(S_i) - (y - i) \log(S_{i+1})$$

where $y$ is the continuous target offset, $i = \lfloor y \rfloor$, and $S_i$ is the softmax probability for bin $i$. The total loss is a weighted sum:

$$\mathcal{L} = \lambda_{cls} \cdot \mathcal{L}_{cls} + \lambda_{box} \cdot \mathcal{L}_{box} + \lambda_{dfl} \cdot \mathcal{L}_{dfl}$$

with default weights $\lambda_{cls} = 0.5$, $\lambda_{box} = 7.5$, and $\lambda_{dfl} = 1.5$. The high box weight reflects the importance of precise localization in object detection.
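Before assembling the full loss class, here is a minimal sketch of the DFL term for a single regression target (illustrative numbers, with random logits standing in for real head outputs): a continuous offset like y = 2.7 is split between bins 2 and 3 with weights 0.3 and 0.7, and the expected value of the softmax distribution decodes back to a continuous offset.
reg_max = 16
logits = torch.randn(1, reg_max)      # raw bin logits for a single LTRB offset
y = torch.tensor([2.7])               # continuous target offset, in stride units

left = y.long()                       # floor bin   -> 2
right = left + 1                      # ceiling bin -> 3
w_right = y - left.float()            # 0.7, weight on the right bin
w_left = 1.0 - w_right                # 0.3, weight on the left bin

dfl = (F.cross_entropy(logits, left, reduction='none') * w_left +
       F.cross_entropy(logits, right, reduction='none') * w_right)
print(f"DFL loss for y=2.7: {dfl.item():.4f}")

# Decoding works the other way: the expectation of the softmax distribution
# over bins recovers a continuous offset, which is what DFLHead.forward computes.
decoded = (F.softmax(logits, dim=1) * torch.arange(reg_max, dtype=torch.float32)).sum()
print(f"decoded offset from the same logits: {decoded.item():.4f}")
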
class YOLOv11Loss(nn.Module):
    """Composite loss for YOLOv11 training.
    
    Components:
        1. Classification: BCE with soft labels from TAL
        2. Box regression: CIoU loss
        3. Distribution Focal Loss: cross-entropy on DFL bins
    """
    
    def __init__(self, num_classes: int = 80, reg_max: int = 16, strides: List[int] = [8, 16, 32],
                 cls_weight: float = 0.5, box_weight: float = 7.5, dfl_weight: float = 1.5):
        super().__init__()
        self.num_classes = num_classes
        self.reg_max = reg_max
        self.strides = strides
        self.cls_weight = cls_weight
        self.box_weight = box_weight
        self.dfl_weight = dfl_weight
        self.assigner = TaskAlignedAssigner()
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
    
    def _make_anchor_points(self, feat_sizes, device):
        """Generate anchor points for all feature levels."""
        all_points = []
        all_strides = []
        for (h, w), stride in zip(feat_sizes, self.strides):
            sy, sx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
            points = torch.stack([sx.flatten(), sy.flatten()], dim=-1).float()
            points = (points + 0.5) * stride  # center of each cell in image coords
            all_points.append(points)
            all_strides.append(torch.full((h * w,), stride, dtype=torch.float32))
        return torch.cat(all_points).to(device), torch.cat(all_strides).to(device)
    
    def _decode_boxes(self, box_pred, anchor_points, strides):
        """Decode LTRB offsets to x1y1x2y2 boxes."""
        lt = box_pred[:, :2] * strides.unsqueeze(-1)
        rb = box_pred[:, 2:] * strides.unsqueeze(-1)
        x1y1 = anchor_points - lt
        x2y2 = anchor_points + rb
        return torch.cat([x1y1, x2y2], dim=-1)
    
    def forward(self, predictions, gt_boxes_list, gt_labels_list):
        """
        Args:
            predictions: list of (cls_pred, box_pred, box_raw) per scale
            gt_boxes_list: list of (num_gt, 4) per image, normalized [cx, cy, w, h]
            gt_labels_list: list of (num_gt,) per image
        """
        device = predictions[0][0].device
        batch_size = predictions[0][0].shape[0]
        
        feat_sizes = [(p[0].shape[2], p[0].shape[3]) for p in predictions]
        anchor_points, anchor_strides = self._make_anchor_points(feat_sizes, device)
        
        # Concatenate predictions across scales
        all_cls = torch.cat([p[0].flatten(2).permute(0, 2, 1) for p in predictions], dim=1)
        all_box = torch.cat([p[1].flatten(2).permute(0, 2, 1) for p in predictions], dim=1)
        all_raw = torch.cat([p[2].flatten(2).permute(0, 2, 1) for p in predictions], dim=1)
        
        total_cls_loss = torch.tensor(0.0, device=device)
        total_box_loss = torch.tensor(0.0, device=device)
        total_dfl_loss = torch.tensor(0.0, device=device)
        num_pos = 0
        
        for b in range(batch_size):
            cls_pred = all_cls[b].sigmoid()  # (num_anchors, num_classes)
            box_pred = all_box[b]            # (num_anchors, 4) LTRB
            raw_pred = all_raw[b]            # (num_anchors, 4*reg_max)
            
            # Decode predicted boxes
            pred_bboxes = self._decode_boxes(box_pred, anchor_points, anchor_strides)
            
            gt_boxes = gt_boxes_list[b]
            gt_labels = gt_labels_list[b]
            
            if len(gt_boxes) == 0:
                total_cls_loss += self.bce(all_cls[b], torch.zeros_like(all_cls[b])).sum()
                continue
            
            # Convert GT from normalized [cx, cy, w, h] to [x1, y1, x2, y2] in pixels
            # (the hard-coded 640 assumes the dataset's 640x640 input size)
            gt_xyxy = torch.zeros_like(gt_boxes)
            gt_xyxy[:, 0] = (gt_boxes[:, 0] - gt_boxes[:, 2] / 2) * 640
            gt_xyxy[:, 1] = (gt_boxes[:, 1] - gt_boxes[:, 3] / 2) * 640
            gt_xyxy[:, 2] = (gt_boxes[:, 0] + gt_boxes[:, 2] / 2) * 640
            gt_xyxy[:, 3] = (gt_boxes[:, 1] + gt_boxes[:, 3] / 2) * 640
            
            # Task-aligned assignment
            assigned_labels, assigned_bboxes, assigned_scores = self.assigner.assign(
                cls_pred, pred_bboxes, gt_labels.long(), gt_xyxy,
                anchor_points, anchor_strides
            )
            
            fg_mask = assigned_labels >= 0
            num_fg = fg_mask.sum().item()
            num_pos += num_fg
            
            # Classification loss (BCE with soft labels)
            cls_targets = assigned_scores.to(device)
            total_cls_loss += self.bce(all_cls[b], cls_targets).sum()
            
            if num_fg > 0:
                # Box loss (CIoU)
                fg_pred_boxes = pred_bboxes[fg_mask]
                fg_gt_boxes = assigned_bboxes[fg_mask].to(device)
                ciou = compute_iou(fg_pred_boxes, fg_gt_boxes, mode='ciou')
                ciou_diag = torch.diag(ciou)
                total_box_loss += (1.0 - ciou_diag).sum()
                
                # DFL loss
                fg_raw = raw_pred[fg_mask]  # (num_fg, 4*reg_max)
                fg_raw = fg_raw.view(-1, self.reg_max)  # (num_fg*4, reg_max)
                # Target: continuous LTRB offsets
                fg_target_ltrb = torch.zeros((num_fg, 4), device=device)
                fg_target_ltrb[:, :2] = (anchor_points[fg_mask] - fg_gt_boxes[:, :2]) / anchor_strides[fg_mask].unsqueeze(-1)
                fg_target_ltrb[:, 2:] = (fg_gt_boxes[:, 2:] - anchor_points[fg_mask]) / anchor_strides[fg_mask].unsqueeze(-1)
                fg_target_ltrb = fg_target_ltrb.clamp(0, self.reg_max - 1 - 0.01)
                target_flat = fg_target_ltrb.view(-1)
                # DFL: cross-entropy between adjacent integer bins
                target_left = target_flat.long()
                target_right = target_left + 1
                weight_right = target_flat - target_left.float()
                weight_left = 1.0 - weight_right
                dfl_loss = (
                    F.cross_entropy(fg_raw, target_left, reduction='none') * weight_left +
                    F.cross_entropy(fg_raw, target_right.clamp(max=self.reg_max - 1), reduction='none') * weight_right
                )
                total_dfl_loss += dfl_loss.sum()
        
        num_pos = max(num_pos, 1)
        loss_cls = self.cls_weight * total_cls_loss / num_pos
        loss_box = self.box_weight * total_box_loss / num_pos
        loss_dfl = self.dfl_weight * total_dfl_loss / num_pos
        total_loss = loss_cls + loss_box + loss_dfl
        
        return total_loss, {
            'cls_loss': loss_cls.item(),
            'box_loss': loss_box.item(),
            'dfl_loss': loss_dfl.item(),
            'total_loss': total_loss.item(),
            'num_pos': num_pos
        }

print("YOLOv11Loss ready.")
YOLOv11Loss ready.
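
Before streaming real data, a quick smoke test can confirm the loss runs end-to-end and backpropagates. We hand it randomly initialized tensors shaped like the head outputs at strides 8/16/32 for one image with a single made-up ground-truth box; with random predictions the assigner may keep only a few (or even zero) positive anchors, which is expected here.
num_classes, reg_max = 80, 16
feat_sizes = [(80, 80), (40, 40), (20, 20)]   # strides 8, 16, 32 for a 640x640 input

fake_preds = []
for h, w in feat_sizes:
    cls = torch.randn(1, num_classes, h, w, requires_grad=True)    # class logits
    box = torch.rand(1, 4, h, w, requires_grad=True) * 8           # LTRB offsets in stride units
    raw = torch.randn(1, 4 * reg_max, h, w, requires_grad=True)    # DFL bin logits
    fake_preds.append((cls, box, raw))

smoke_gt_boxes = [torch.tensor([[0.5, 0.5, 0.2, 0.3]])]   # one GT box, normalized [cx, cy, w, h]
smoke_gt_labels = [torch.tensor([3])]

smoke_loss_fn = YOLOv11Loss(num_classes=num_classes, reg_max=reg_max)
total, parts = smoke_loss_fn(fake_preds, smoke_gt_boxes, smoke_gt_labels)
total.backward()   # verifies the loss is differentiable end-to-end
print(parts)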

Real COCO data for training

Instead of training on synthetic colored rectangles, we stream real COCO images from detection-datasets/coco on the Hugging Face Hub. We buffer 32 images in memory for this demo to keep training fast while using real-world data.
Data source: Images streamed from detection-datasets/coco. See our HF COCO streaming tutorial for details.
class COCOStreamDetectionDataset(torch.utils.data.Dataset):
    """Buffer real COCO images from HF streaming for detection training.

    Pre-fetches max_samples images via streaming and stores them in memory,
    providing random access and len() support for the DataLoader.
    Annotations are converted to YOLO format [cx, cy, w, h] normalized.
    """

    def __init__(self, split='train', max_samples=32, img_size=640, num_classes=80):
        self.img_size = img_size
        self.num_classes = num_classes
        self.data = []

        print(f"Streaming {max_samples} COCO images from Hugging Face...")
        ds = load_dataset('detection-datasets/coco', split=split, streaming=True)

        for example in ds:
            if len(self.data) >= max_samples:
                break

            img_pil = example['image'].convert('RGB')
            img_np = np.array(img_pil)
            h, w = img_np.shape[:2]

            bboxes = example['objects']['bbox']
            cats = example['objects']['category']

            boxes = []
            labels = []
            for bbox, cat_id in zip(bboxes, cats):
                bx, by, bw, bh = bbox
                if bw <= 0 or bh <= 0:
                    continue
                cx = (bx + bw / 2) / w
                cy = (by + bh / 2) / h
                boxes.append([cx, cy, bw / w, bh / h])
                labels.append(int(cat_id))

            if len(boxes) == 0:
                continue

            # Resize to model input size
            img_resized = np.array(img_pil.resize((self.img_size, self.img_size)))
            img_tensor = torch.from_numpy(img_resized).permute(2, 0, 1).float() / 255.0

            boxes_t = torch.tensor(boxes, dtype=torch.float32)
            labels_t = torch.tensor(labels, dtype=torch.long)

            self.data.append((img_tensor, boxes_t, labels_t))

        print(f"Buffered {len(self.data)} COCO images")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collate_fn(batch):
    """Custom collate: images are stacked, boxes/labels stay as lists."""
    imgs, boxes, labels = zip(*batch)
    imgs = torch.stack(imgs, dim=0)
    targets = torch.zeros(len(imgs))
    return imgs, targets, list(boxes), list(labels)


# Create dataset and dataloader with real COCO images
dataset = COCOStreamDetectionDataset(max_samples=32, num_classes=80)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

# Visualize samples
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
for i, ax in enumerate(axes):
    img, boxes, labels = dataset[i]
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.set_title(f'Image {i}: {len(boxes)} objects')
    ax.axis('off')
    for box in boxes:
        cx, cy, w, h = box.numpy() * 640
        rect = plt.Rectangle((cx - w/2, cy - h/2), w, h,
                             linewidth=2, edgecolor='white', facecolor='none')
        ax.add_patch(rect)
plt.suptitle('Real COCO Training Data', fontsize=14)
plt.tight_layout()
plt.show()

print(f"Dataset: {len(dataset)} images, DataLoader: {len(loader)} batches")
Streaming 32 COCO images from Hugging Face...
Buffered 32 COCO images
Output from cell 7
Dataset: 32 images, DataLoader: 8 batches

Training loop

The training loop follows a standard PyTorch pattern with a few detection-specific details:
  • Gradient clipping (max_norm=10.0) prevents exploding gradients, which can occur when the loss components have very different magnitudes early in training.
  • We track individual loss components (classification, box, DFL) to diagnose training behavior.
  • The number of positive (foreground) assignments per batch is logged to ensure the assigner is working correctly.
def train_one_epoch(model, dataloader, optimizer, loss_fn, device, epoch):
    model.train()
    epoch_losses = {'cls_loss': 0, 'box_loss': 0, 'dfl_loss': 0, 'total_loss': 0}
    
    for batch_idx, (imgs, targets, boxes_list, labels_list) in enumerate(dataloader):
        imgs = imgs.to(device)
        
        # Move GT to device
        gt_boxes = [b.to(device) for b in boxes_list]
        gt_labels = [l.to(device) for l in labels_list]
        
        # Forward
        predictions = model(imgs)
        loss, loss_dict = loss_fn(predictions, gt_boxes, gt_labels)
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        optimizer.step()
        
        for k in epoch_losses:
            epoch_losses[k] += loss_dict[k]
        
        if batch_idx % 2 == 0:
            print(f"  Batch {batch_idx}: loss={loss_dict['total_loss']:.4f} "
                  f"(cls={loss_dict['cls_loss']:.4f}, box={loss_dict['box_loss']:.4f}, "
                  f"dfl={loss_dict['dfl_loss']:.4f}, pos={loss_dict['num_pos']})")
    
    n = len(dataloader)
    return {k: v / n for k, v in epoch_losses.items()}

Running the training demo

We train for 5 epochs on our tiny 32-image COCO subset. The goal is not to achieve good detection performance (that requires far more data and many more epochs), but to verify that:
  1. The forward pass produces valid predictions at all three scales
  2. The TAL assigner finds positive anchors for the ground-truth boxes
  3. All three loss components produce valid gradients
  4. The total loss decreases over training
We use AdamW with a cosine-annealed learning rate, a common setup for small training runs like this one.
# Small-scale training demo with real COCO data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = YOLOv11(num_classes=80).to(device)
loss_fn = YOLOv11Loss(num_classes=80)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)

# Cosine LR scheduler
num_epochs = 5
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Training for {num_epochs} epochs on {len(dataset)} real COCO images\n")

history = []
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs} (lr={scheduler.get_last_lr()[0]:.6f})")
    epoch_loss = train_one_epoch(model, loader, optimizer, loss_fn, device, epoch)
    scheduler.step()
    history.append(epoch_loss)
    print(f"  -> Avg loss: {epoch_loss['total_loss']:.4f}")

print("\nTraining complete!")
Using device: cuda
Model parameters: 110,212,128
Training for 5 epochs on 32 real COCO images


Epoch 1/5 (lr=0.001000)
  Batch 0: loss=4213.8818 (cls=4193.7261, box=3.3395, dfl=16.8166, pos=227)
  Batch 2: loss=4270.5688 (cls=4249.9336, box=4.0919, dfl=16.5435, pos=207)
  Batch 4: loss=2912.0315 (cls=2890.2510, box=3.7757, dfl=18.0049, pos=269)
  Batch 6: loss=2059.2117 (cls=2040.1794, box=3.4031, dfl=15.6291, pos=340)
  -> Avg loss: 3492.3172

Epoch 2/5 (lr=0.000905)
  Batch 0: loss=4035.7471 (cls=4017.5625, box=3.4557, dfl=14.7287, pos=152)
  Batch 2: loss=2407.4922 (cls=2388.5825, box=3.5486, dfl=15.3612, pos=231)
  Batch 4: loss=1671.8571 (cls=1652.5510, box=3.6455, dfl=15.6606, pos=302)
  Batch 6: loss=2806.7612 (cls=2788.8137, box=3.1957, dfl=14.7517, pos=161)
  -> Avg loss: 2632.2373

Epoch 3/5 (lr=0.000655)
  Batch 0: loss=5208.6025 (cls=5191.1230, box=3.4236, dfl=14.0559, pos=78)
  Batch 2: loss=1585.1288 (cls=1566.7739, box=3.3842, dfl=14.9707, pos=238)
  Batch 4: loss=1801.5144 (cls=1783.7823, box=2.9455, dfl=14.7865, pos=196)
  Batch 6: loss=1900.2883 (cls=1883.4437, box=3.3429, dfl=13.5017, pos=172)
  -> Avg loss: 1822.9444

Epoch 4/5 (lr=0.000345)
  Batch 0: loss=1511.5569 (cls=1493.6290, box=3.3001, dfl=14.6277, pos=201)
  Batch 2: loss=809.4685 (cls=792.3691, box=3.1608, dfl=13.9386, pos=363)
  Batch 4: loss=1549.5906 (cls=1532.0712, box=3.5075, dfl=14.0120, pos=180)
  Batch 6: loss=740.3965 (cls=723.1921, box=3.6049, dfl=13.5995, pos=368)
  -> Avg loss: 1255.7059

Epoch 5/5 (lr=0.000095)
  Batch 0: loss=821.8530 (cls=804.5747, box=3.4525, dfl=13.8258, pos=318)
  Batch 2: loss=1521.7334 (cls=1505.7045, box=3.0240, dfl=13.0050, pos=165)
  Batch 4: loss=1769.9384 (cls=1753.2113, box=3.3225, dfl=13.4045, pos=140)
  Batch 6: loss=781.3149 (cls=765.2618, box=2.9836, dfl=13.0695, pos=321)
  -> Avg loss: 1169.2647

Training complete!

Loss curves visualization

Plotting the individual loss components over training helps diagnose issues:
  • Classification loss should decrease as the model learns to distinguish object classes from background
  • Box (CIoU) loss should decrease as predicted boxes align better with ground truth
  • DFL loss should decrease as the distribution predictions sharpen around the correct offsets
If one component plateaus while others decrease, it may indicate an imbalance in loss weights.
fig, axes = plt.subplots(1, 4, figsize=(20, 4))
keys = ['total_loss', 'cls_loss', 'box_loss', 'dfl_loss']
titles = ['Total Loss', 'Classification Loss', 'Box (CIoU) Loss', 'DFL Loss']

for ax, key, title in zip(axes, keys, titles):
    values = [h[key] for h in history]
    ax.plot(range(1, len(values)+1), values, 'b-o', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

plt.suptitle('Training Loss Curves', fontsize=14)
plt.tight_layout()
plt.show()
Output from cell 10

Summary

In this notebook, we built the complete training pipeline for YOLOv11 from scratch. Here is a recap of the key components:
  1. CIoU provides complete geometric alignment by combining overlap area, center-point distance, and aspect ratio consistency into a single differentiable metric. This gives the optimizer rich gradient information for box regression, unlike basic IoU which produces zero gradients for non-overlapping boxes.
  2. Task-Aligned Learning (TAL) assigns ground-truth boxes to anchor points based on both classification confidence and localization quality. The alignment metric $t = s^\alpha \cdot u^\beta$ ensures that labels are assigned to anchors where the model is already performing well, creating a virtuous cycle during training.
  3. Distribution Focal Loss (DFL) enables precise box regression by predicting a probability distribution over discrete offset bins rather than a single scalar. The weighted cross-entropy between adjacent bins preserves the continuous nature of the target.
  4. The composite loss balances classification ($\lambda = 0.5$), localization ($\lambda = 7.5$), and distribution quality ($\lambda = 1.5$). The heavy weight on box regression reflects the critical importance of precise localization in object detection.

Next steps

In Notebook 5, we will build the inference pipeline: decoding predictions into bounding boxes, applying Non-Maximum Suppression (NMS), and evaluating detection quality using COCO metrics (mAP, AP50, AP75).