class ConvBNSiLU(nn.Module):
def __init__(self, in_c, out_c, k=1, s=1, p=None, g=1):
super().__init__()
if p is None:
p = k // 2
self.conv = nn.Conv2d(in_c, out_c, k, s, p, groups=g, bias=False)
self.bn = nn.BatchNorm2d(out_c)
self.act = nn.SiLU(inplace=True)
def forward(self, x):
return self.act(self.bn(self.conv(x)))
class Bottleneck(nn.Module):
    """Two stacked 3x3 ``ConvBNSiLU`` layers with an optional residual add.

    Args:
        in_c: input channels.
        out_c: output channels.
        shortcut: request a residual connection; it is only applied when the
            block preserves the channel count (``in_c == out_c``).
        e: expansion ratio; the intermediate width is ``int(out_c * e)``.
    """

    def __init__(self, in_c, out_c, shortcut=True, e=0.5):
        super().__init__()
        hidden = int(out_c * e)
        self.cv1 = ConvBNSiLU(in_c, hidden, 3)
        self.cv2 = ConvBNSiLU(hidden, out_c, 3)
        self.add = shortcut and in_c == out_c

    def forward(self, x):
        out = self.cv2(self.cv1(x))
        if self.add:
            return x + out
        return out
class C3k2(nn.Module):
    """Split / chained-bottleneck / concatenate block.

    ``cv1`` projects the input to ``2 * c`` channels (``c = int(out_c * e)``),
    which are split into two halves. Each bottleneck consumes the most recent
    feature map and its output is kept, so ``2 + n`` maps of ``c`` channels are
    concatenated and fused back to ``out_c`` channels by ``cv2``.

    Args:
        in_c: input channels.
        out_c: output channels.
        n: number of chained bottlenecks.
        shortcut: forwarded to each ``Bottleneck``.
        e: width ratio used to compute ``c``.
    """

    def __init__(self, in_c, out_c, n=2, shortcut=True, e=0.5):
        super().__init__()
        self.c = int(out_c * e)
        width = self.c
        self.cv1 = ConvBNSiLU(in_c, 2 * width, 1)
        self.cv2 = ConvBNSiLU((2 + n) * width, out_c, 1)
        self.bottlenecks = nn.ModuleList(
            [Bottleneck(width, width, shortcut) for _ in range(n)]
        )

    def forward(self, x):
        feats = [*self.cv1(x).chunk(2, dim=1)]
        for block in self.bottlenecks:
            # Each stage feeds on the newest map; all outputs are kept.
            feats.append(block(feats[-1]))
        return self.cv2(torch.cat(feats, dim=1))
class SPPF(nn.Module):
    """Fast spatial pyramid pooling built from three chained max-pools.

    The input is reduced to ``in_c // 2`` channels. The reduced map and its
    1x, 2x and 3x max-pooled versions (kernel ``k``, stride 1, padding
    ``k // 2``) are concatenated and projected to ``out_c`` channels.

    Args:
        in_c: input channels.
        out_c: output channels.
        k: max-pool kernel size.
    """

    def __init__(self, in_c, out_c, k=5):
        super().__init__()
        half = in_c // 2
        self.cv1 = ConvBNSiLU(in_c, half, 1)
        self.cv2 = ConvBNSiLU(half * 4, out_c, 1)
        self.pool = nn.MaxPool2d(k, stride=1, padding=k // 2)

    def forward(self, x):
        pyramid = [self.cv1(x)]
        for _ in range(3):
            pyramid.append(self.pool(pyramid[-1]))
        return self.cv2(torch.cat(pyramid, dim=1))
class YOLOv11Backbone(nn.Module):
    """Convolutional backbone returning three multi-scale feature maps.

    The stem and each of the four stages start with a stride-2 3x3
    ``ConvBNSiLU``, so the returned maps are downsampled 8x, 16x and 32x
    relative to the input (rounding up per halving). They carry
    ``4 * base_channels``, ``8 * base_channels`` and ``16 * base_channels``
    channels. The last stage ends with an ``SPPF`` block.

    Args:
        in_channels: channels of the input image tensor.
        base_channels: width of the stem; later stages scale from it.
    """

    def __init__(self, in_channels=3, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * mult for mult in (1, 2, 4, 8, 16))
        self.stem = ConvBNSiLU(in_channels, c1, 3, 2)
        self.stage1 = nn.Sequential(ConvBNSiLU(c1, c2, 3, 2), C3k2(c2, c2, n=2))
        self.stage2 = nn.Sequential(ConvBNSiLU(c2, c3, 3, 2), C3k2(c3, c3, n=2))
        self.stage3 = nn.Sequential(ConvBNSiLU(c3, c4, 3, 2), C3k2(c4, c4, n=2))
        self.stage4 = nn.Sequential(
            ConvBNSiLU(c4, c5, 3, 2),
            C3k2(c5, c5, n=2),
            SPPF(c5, c5),
        )

    def forward(self, x):
        """Return ``(p3, p4, p5)``, the outputs of stages 2, 3 and 4."""
        stem_out = self.stem(x)
        p3 = self.stage2(self.stage1(stem_out))
        p4 = self.stage3(p3)
        p5 = self.stage4(p4)
        return p3, p4, p5
class FPN(nn.Module):
    """Top-down feature-pyramid path.

    P5 is reduced to ``c4`` channels, nearest-upsampled to P4's spatial size
    and fused with P4. The fused map is then reduced to ``c3`` channels,
    upsampled to P3's size and fused with P3.

    Returns:
        ``(fpn_p3, fpn_p4, p5)`` with ``c3``, ``c4`` and ``c5`` channels;
        ``p5`` is passed through unchanged.
    """

    def __init__(self, c3=256, c4=512, c5=1024):
        super().__init__()
        self.lateral_p5 = ConvBNSiLU(c5, c4, 1)
        self.lateral_p4 = ConvBNSiLU(c4, c3, 1)
        self.fpn_p4 = C3k2(2 * c4, c4, n=2, shortcut=False)
        self.fpn_p3 = C3k2(2 * c3, c3, n=2, shortcut=False)

    def forward(self, p3, p4, p5):
        top = self.lateral_p5(p5)
        top = F.interpolate(top, size=p4.shape[2:], mode='nearest')
        fpn_p4 = self.fpn_p4(torch.cat([top, p4], dim=1))

        mid = self.lateral_p4(fpn_p4)
        mid = F.interpolate(mid, size=p3.shape[2:], mode='nearest')
        fpn_p3 = self.fpn_p3(torch.cat([mid, p3], dim=1))

        return fpn_p3, fpn_p4, p5
class PAN(nn.Module):
    """Bottom-up path-aggregation stage.

    ``fpn_p3`` is downsampled with a stride-2 3x3 conv and fused with
    ``fpn_p4`` into a ``c4``-channel map. That map is downsampled again and
    fused with ``p5`` into a ``c5``-channel map.

    Returns:
        ``(fpn_p3, pan_p4, pan_p5)``; ``fpn_p3`` is passed through unchanged.
    """

    def __init__(self, c3=256, c4=512, c5=1024):
        super().__init__()
        self.down_p3 = ConvBNSiLU(c3, c3, 3, 2)
        self.down_p4 = ConvBNSiLU(c4, c4, 3, 2)
        self.pan_p4 = C3k2(c3 + c4, c4, n=2, shortcut=False)
        self.pan_p5 = C3k2(c4 + c5, c5, n=2, shortcut=False)

    def forward(self, fpn_p3, fpn_p4, p5):
        pan_p4 = self.pan_p4(torch.cat([self.down_p3(fpn_p3), fpn_p4], dim=1))
        pan_p5 = self.pan_p5(torch.cat([self.down_p4(pan_p4), p5], dim=1))
        return fpn_p3, pan_p4, pan_p5
class C2PSA(nn.Module):
    """Split block with bottleneck stages and channel-wise gating.

    ``cv1`` maps the input to ``2 * c`` channels (``c = in_channels // 2``),
    which are split into two halves. The second half goes through ``n``
    stages. Each stage applies a ``Bottleneck`` and then multiplies the result
    by a per-channel gate (global average pool -> Linear -> SiLU -> Linear ->
    sigmoid). The first half and the output of the last stage are
    concatenated and projected to ``out_channels`` by ``cv2``. Intermediate
    stage outputs are not concatenated. With ``n == 0`` the two halves are
    simply re-joined.

    Args:
        in_channels: input channels (``in_channels // 2`` must be >= 1).
        out_channels: output channels.
        n: number of bottleneck + gate stages.
    """

    def __init__(self, in_channels, out_channels, n=1):
        super().__init__()
        self.c = in_channels // 2
        self.cv1 = ConvBNSiLU(in_channels, 2 * self.c, 1)
        self.cv2 = ConvBNSiLU(2 * self.c, out_channels, 1)
        # Hidden width of the gating MLP. Clamped to >= 1: for c < 4 the plain
        # c // 4 is 0. The second Linear would then output only its bias, and
        # the gate would become a constant independent of the input.
        hidden = max(1, self.c // 4)
        self.attention = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                nn.Linear(self.c, hidden), nn.SiLU(inplace=True),
                nn.Linear(hidden, self.c), nn.Sigmoid()
            ) for _ in range(n)
        ])
        self.bottlenecks = nn.ModuleList(
            [Bottleneck(self.c, self.c, shortcut=True) for _ in range(n)]
        )

    def forward(self, x):
        y = list(self.cv1(x).chunk(2, dim=1))
        for attn, bn in zip(self.attention, self.bottlenecks):
            feat = bn(y[-1])
            # (B, c) gate -> (B, c, 1, 1) so it broadcasts over H and W.
            att_weights = attn(feat).unsqueeze(-1).unsqueeze(-1)
            y.append(feat * att_weights)
        return self.cv2(torch.cat([y[0], y[-1]], dim=1))
class DFLHead(nn.Module):
    """Decode per-group distribution logits into expected bin values.

    The input's channels are viewed as 4 groups of ``reg_max`` logits. This
    presumably means one group per box-side distance, as produced by
    ``DetectionHead``; confirm against the loss. Each group is
    softmax-normalised over its bins, and the expectation over bin indices
    ``0 .. reg_max - 1`` is returned.

    Input ``(B, 4 * reg_max, H, W)`` -> output ``(B, 4, H, W)``.
    """

    def __init__(self, reg_max=16):
        super().__init__()
        self.reg_max = reg_max
        # Bin indices 0..reg_max-1, registered as a buffer so it moves with the
        # module's device/dtype and appears in state_dict.
        self.register_buffer('project', torch.arange(reg_max, dtype=torch.float32))

    def forward(self, x):
        batch, _, height, width = x.shape
        logits = x.view(batch, 4, self.reg_max, height, width)
        probs = F.softmax(logits, dim=2)
        bins = self.project.view(1, 1, -1, 1, 1)
        return (probs * bins).sum(dim=2)
class DetectionHead(nn.Module):
    """Per-level prediction head with separate classification and box branches.

    Each branch is two 3x3 ``ConvBNSiLU`` layers followed by a 1x1 conv. The
    classification branch emits ``num_classes`` channels. The regression
    branch emits ``4 * reg_max`` channels, which ``DFLHead`` decodes.

    Returns:
        ``(cls_pred, box_pred, box_raw)``: raw class scores (no activation
        applied) ``(B, num_classes, H, W)``, decoded box values
        ``(B, 4, H, W)`` and the undecoded regression logits
        ``(B, 4 * reg_max, H, W)``.
    """

    def __init__(self, in_channels, num_classes=80, reg_max=16):
        super().__init__()
        self.num_classes = num_classes
        self.reg_max = reg_max
        self.cls_branch = self._make_branch(in_channels, num_classes)
        self.reg_branch = self._make_branch(in_channels, 4 * reg_max)
        self.dfl = DFLHead(reg_max)

    @staticmethod
    def _make_branch(channels, out_channels):
        """Build two 3x3 ConvBNSiLU layers followed by a 1x1 output conv."""
        return nn.Sequential(
            ConvBNSiLU(channels, channels, 3),
            ConvBNSiLU(channels, channels, 3),
            nn.Conv2d(channels, out_channels, 1),
        )

    def forward(self, x):
        cls_pred = self.cls_branch(x)
        box_raw = self.reg_branch(x)
        return cls_pred, self.dfl(box_raw), box_raw
class YOLOv11(nn.Module):
    """Full detector: backbone -> FPN (top-down) -> PAN (bottom-up) -> heads.

    A ``C2PSA`` block refines the PAN P5 output before its head. One
    ``DetectionHead`` per level produces predictions for the three scales,
    which sit at nominal strides 8, 16 and 32 (``self.strides``).

    Args:
        num_classes: number of classification outputs per location.
        reg_max: number of distribution bins per box-side group.
        base_channels: backbone stem width; level widths are 4x, 8x and 16x.
        in_channels: channels of the input image tensor (default 3, the value
            previously hard-coded through the backbone default).
    """

    def __init__(self, num_classes=80, reg_max=16, base_channels=64, in_channels=3):
        super().__init__()
        c3, c4, c5 = base_channels * 4, base_channels * 8, base_channels * 16
        self.backbone = YOLOv11Backbone(in_channels=in_channels, base_channels=base_channels)
        self.fpn = FPN(c3, c4, c5)
        self.pan = PAN(c3, c4, c5)
        # NOTE(review): reference YOLO11 configs place C2PSA at the end of the
        # backbone (after SPPF); here it is applied after PAN. Confirm this is
        # intentional.
        self.c2psa = C2PSA(c5, c5, n=1)
        self.head_p3 = DetectionHead(c3, num_classes, reg_max)
        self.head_p4 = DetectionHead(c4, num_classes, reg_max)
        self.head_p5 = DetectionHead(c5, num_classes, reg_max)
        self.strides = [8, 16, 32]
        self.num_classes = num_classes
        self.reg_max = reg_max

    def forward(self, x):
        """Return ``[pred_p3, pred_p4, pred_p5]``.

        Each entry is a ``(cls_pred, box_pred, box_raw)`` tuple from
        ``DetectionHead``.
        """
        p3, p4, p5 = self.backbone(x)
        fpn_p3, fpn_p4, fpn_p5 = self.fpn(p3, p4, p5)
        pan_p3, pan_p4, pan_p5 = self.pan(fpn_p3, fpn_p4, fpn_p5)
        pan_p5 = self.c2psa(pan_p5)
        pred_p3 = self.head_p3(pan_p3)
        pred_p4 = self.head_p4(pan_p4)
        pred_p5 = self.head_p5(pan_p5)
        return [pred_p3, pred_p4, pred_p5]