class ConvBNSiLU(nn.Module):
    """Conv2d -> BatchNorm2d -> SiLU, the standard YOLO convolution unit.

    When ``p`` is None, padding defaults to ``k // 2`` so odd kernels at
    stride 1 preserve spatial size. The conv carries no bias because the
    following BatchNorm absorbs it.
    """

    def __init__(self, c_in, c_out, k=1, s=1, p=None, g=1):
        super().__init__()
        pad = k // 2 if p is None else p
        self.conv = nn.Conv2d(c_in, c_out, k, s, pad, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch-norm and SiLU in sequence."""
        return self.act(self.bn(self.conv(x)))
class Bottleneck(nn.Module):
    """Two stacked ConvBNSiLU layers with an optional identity shortcut.

    The residual add is only enabled when requested AND the input/output
    channel counts match, so the element-wise sum is always well-formed.
    """

    def __init__(self, c_in, c_out, shortcut=True, k=(3, 3), e=0.5):
        super().__init__()
        c_hid = int(c_out * e)  # hidden width scaled by expansion factor e
        self.cv1 = ConvBNSiLU(c_in, c_hid, k[0])
        self.cv2 = ConvBNSiLU(c_hid, c_out, k[1])
        self.add = shortcut and c_in == c_out

    def forward(self, x):
        out = self.cv2(self.cv1(x))
        return x + out if self.add else out
class C3k2(nn.Module):
    """CSP-style block: split features into two branches, run ``n``
    bottlenecks sequentially on one of them, then concatenate every
    intermediate tensor and fuse with a 1x1 conv.
    """

    def __init__(self, c_in, c_out, n=1, shortcut=True, e=0.5):
        super().__init__()
        self.c = int(c_out * e)  # per-branch width
        self.cv1 = ConvBNSiLU(c_in, 2 * self.c, 1)
        # cv1 yields 2 chunks; each bottleneck contributes one more tensor
        # of width c, hence (2 + n) * c channels into the fuse conv.
        self.cv2 = ConvBNSiLU((2 + n) * self.c, c_out, 1)
        self.bottlenecks = nn.ModuleList(
            Bottleneck(self.c, self.c, shortcut, (3, 3), 1.0) for _ in range(n)
        )

    def forward(self, x):
        a, b = self.cv1(x).chunk(2, dim=1)
        outs = [a, b]
        for block in self.bottlenecks:
            outs.append(block(outs[-1]))
        return self.cv2(torch.cat(outs, dim=1))
class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast.

    Three chained stride-1 max-pools emulate pooling at growing receptive
    fields; the input and all three pooled maps are concatenated and fused
    with a 1x1 conv.
    """

    def __init__(self, c_in, c_out, k=5):
        super().__init__()
        c_hid = c_in // 2
        self.cv1 = ConvBNSiLU(c_in, c_hid, 1)
        self.cv2 = ConvBNSiLU(c_hid * 4, c_out, 1)
        # padding k // 2 keeps spatial size constant across pools
        self.pool = nn.MaxPool2d(k, stride=1, padding=k // 2)

    def forward(self, x):
        feats = [self.cv1(x)]
        for _ in range(3):
            feats.append(self.pool(feats[-1]))
        return self.cv2(torch.cat(feats, dim=1))
class YOLOv11Backbone(nn.Module):
    """Five-stage strided backbone producing features at strides 8/16/32.

    Channel widths double each stage starting from ``base``; the deepest
    map is refined by SPPF before being returned as P5.
    """

    def __init__(self, c_in=3, base=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base * m for m in (1, 2, 4, 8, 16))
        self.stem = ConvBNSiLU(c_in, c1, 3, s=2)   # stride 2
        self.s1_down = ConvBNSiLU(c1, c2, 3, s=2)  # stride 4
        self.s1_c3k2 = C3k2(c2, c2, n=2)
        self.s2_down = ConvBNSiLU(c2, c3, 3, s=2)  # stride 8  -> P3
        self.s2_c3k2 = C3k2(c3, c3, n=2)
        self.s3_down = ConvBNSiLU(c3, c4, 3, s=2)  # stride 16 -> P4
        self.s3_c3k2 = C3k2(c4, c4, n=2)
        self.s4_down = ConvBNSiLU(c4, c5, 3, s=2)  # stride 32 -> P5
        self.s4_c3k2 = C3k2(c5, c5, n=2)
        self.sppf = SPPF(c5, c5)

    def forward(self, x):
        """Return the (p3, p4, p5) multi-scale feature maps."""
        x = self.s1_c3k2(self.s1_down(self.stem(x)))
        p3 = self.s2_c3k2(self.s2_down(x))
        p4 = self.s3_c3k2(self.s3_down(p3))
        p5 = self.sppf(self.s4_c3k2(self.s4_down(p4)))
        return p3, p4, p5
class FPN(nn.Module):
    """Top-down feature pyramid: propagate high-level semantics from P5
    down to P3.

    Note that p5 passes through unchanged; only p4 and p3 are refined.
    """

    def __init__(self, c3=256, c4=512, c5=1024):
        super().__init__()
        self.lateral5 = ConvBNSiLU(c5, c4, 1)  # shrink P5 to P4 width
        self.lateral4 = ConvBNSiLU(c4, c3, 1)  # shrink P4 to P3 width
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.fuse4 = C3k2(c4 * 2, c4, n=2)
        self.fuse3 = C3k2(c3 * 2, c3, n=2)

    def forward(self, p3, p4, p5):
        top = self.up(self.lateral5(p5))
        p4 = self.fuse4(torch.cat([p4, top], dim=1))
        mid = self.up(self.lateral4(p4))
        p3 = self.fuse3(torch.cat([p3, mid], dim=1))
        return p3, p4, p5
class PAN(nn.Module):
    """Bottom-up path aggregation: push fine localization detail from P3
    back up to P4 and P5.

    Note that p3 passes through unchanged; only p4 and p5 are refined.
    """

    def __init__(self, c3=256, c4=512, c5=1024):
        super().__init__()
        self.down3 = ConvBNSiLU(c3, c3, 3, s=2)  # halve P3 resolution to match P4
        self.fuse4 = C3k2(c3 + c4, c4, n=2)
        self.down4 = ConvBNSiLU(c4, c4, 3, s=2)  # halve P4 resolution to match P5
        self.fuse5 = C3k2(c4 + c5, c5, n=2)

    def forward(self, p3, p4, p5):
        p4 = self.fuse4(torch.cat([self.down3(p3), p4], dim=1))
        p5 = self.fuse5(torch.cat([self.down4(p4), p5], dim=1))
        return p3, p4, p5
class C2PSA(nn.Module):
    """Attention block applied to the deepest feature map.

    Treats each spatial position as a token of dimension ``c``, runs
    multi-head self-attention over positions, then a 1x1-conv feed-forward
    network; both are residual connections.

    NOTE: ``c`` must be divisible by ``n_heads`` (nn.MultiheadAttention
    requirement).
    """

    def __init__(self, c, n_heads=8):
        super().__init__()
        self.cv1 = ConvBNSiLU(c, c, 1)
        self.attn = nn.MultiheadAttention(c, n_heads, batch_first=True)
        self.ffn = nn.Sequential(ConvBNSiLU(c, c * 2, 1), ConvBNSiLU(c * 2, c, 1))
        self.cv2 = ConvBNSiLU(c, c, 1)

    def forward(self, x):
        y = self.cv1(x)
        B, C, H, W = y.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = y.flatten(2).permute(0, 2, 1)
        attn_out, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        tokens = tokens + attn_out
        # Back to (B, C, H, W); reshape handles the permuted strides.
        y = tokens.permute(0, 2, 1).reshape(B, C, H, W)
        y = y + self.ffn(y)
        return self.cv2(y)
class DFLHead(nn.Module):
    """Single-scale detection head with Distribution Focal Loss regression.

    Each of the four box edges is predicted as a discrete distribution over
    ``reg_max`` bins; the expectation over bin indices (soft-argmax) yields
    the continuous edge distance.

    forward returns (cls_logits, decoded_box_distances, raw_bin_logits);
    the raw logits are kept for the DFL training loss.
    """

    def __init__(self, c_in, num_classes=80, reg_max=16):
        super().__init__()
        self.reg_max = reg_max
        self.num_classes = num_classes
        self.cls_convs = nn.Sequential(ConvBNSiLU(c_in, c_in, 3), ConvBNSiLU(c_in, c_in, 3))
        self.reg_convs = nn.Sequential(ConvBNSiLU(c_in, c_in, 3), ConvBNSiLU(c_in, c_in, 3))
        self.cls_pred = nn.Conv2d(c_in, num_classes, 1)
        self.reg_pred = nn.Conv2d(c_in, 4 * reg_max, 1)
        # Fixed bin centers 0..reg_max-1 for the expectation; frozen weights.
        self.proj = nn.Parameter(torch.arange(reg_max, dtype=torch.float32), requires_grad=False)

    def forward(self, x):
        cls_out = self.cls_pred(self.cls_convs(x))
        reg_raw = self.reg_pred(self.reg_convs(x))
        B, _, H, W = reg_raw.shape
        # (B, 4*reg_max, H, W) -> (B, 4, reg_max, H, W): one distribution per edge.
        probs = F.softmax(reg_raw.view(B, 4, self.reg_max, H, W), dim=2)
        bins = self.proj.view(1, 1, -1, 1, 1)
        reg_box = (probs * bins).sum(dim=2)  # expected distance per edge
        return cls_out, reg_box, reg_raw
class DetectionHead(nn.Module):
    """Multi-scale detection head: one DFLHead per pyramid level.

    Args:
        channels: per-level input channel counts (P3, P4, P5 order).
        num_classes: number of object classes.
        reg_max: number of DFL regression bins per box edge.

    forward takes a list of feature maps (same order as ``channels``) and
    returns the per-level (cls, box, raw) tuples from each DFLHead.
    """

    # FIX: default was a mutable list ([256, 512, 1024]); a tuple avoids the
    # shared-mutable-default pitfall and is backward compatible for callers.
    def __init__(self, channels=(256, 512, 1024), num_classes=80, reg_max=16):
        super().__init__()
        self.heads = nn.ModuleList([DFLHead(c, num_classes, reg_max) for c in channels])

    def forward(self, features):
        return [head(f) for head, f in zip(self.heads, features)]
class YOLOv11(nn.Module):
    """Complete detector.

    Pipeline: backbone -> C2PSA attention on P5 -> FPN (top-down) ->
    PAN (bottom-up) -> per-level detection heads.

    With the default backbone (base=64) the pyramid widths are 256/512/1024,
    matching the FPN/PAN defaults and the C2PSA embedding dim.
    """

    def __init__(self, num_classes=80, reg_max=16):
        super().__init__()
        self.backbone = YOLOv11Backbone()
        self.fpn = FPN()
        self.pan = PAN()
        self.c2psa = C2PSA(1024)
        self.head = DetectionHead(num_classes=num_classes, reg_max=reg_max)

    def forward(self, x):
        p3, p4, p5 = self.backbone(x)
        p5 = self.c2psa(p5)
        pyramid = self.fpn(p3, p4, p5)
        pyramid = self.pan(*pyramid)
        return self.head(list(pyramid))
if __name__ == "__main__":
    # Smoke check: build the model and report its size. Guarded so that
    # importing this module no longer constructs the full network (an
    # expensive import-time side effect) or prints to stdout.
    print("Model components loaded successfully.")
    n_params = sum(p.numel() for p in YOLOv11(num_classes=80).parameters())
    print(f"YOLOv11 parameters: {n_params:,}")