class ResBlock(nn.Module):
    """Residual block: two 3x3 convolutions added to a skip connection.

    Each convolution is followed by ``BatchNorm2d`` when ``use_bn`` is true
    and by ``nn.Identity`` otherwise; the convolutions only carry a bias term
    when no batch normalization follows them.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution (and of the projection
            shortcut, when one is needed).
        use_bn: Whether to insert batch normalization after each convolution.
    """

    def __init__(self, in_channels, out_channels, stride=1, use_bn=True):
        super().__init__()
        self.use_bn = use_bn
        # A conv bias is redundant when BatchNorm follows, so it is only
        # enabled for the BN-free variant.
        conv_bias = not use_bn

        def make_norm():
            # Normalization layer, or a pass-through when BN is disabled.
            return nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            3,
            stride=stride,
            padding=1,
            bias=conv_bias,
        )
        self.bn1 = make_norm()
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            3,
            stride=1,
            padding=1,
            bias=conv_bias,
        )
        self.bn2 = make_norm()
        self.relu = nn.ReLU(inplace=True)

        # The skip path is an empty Sequential (identity) unless the main path
        # changes the spatial size or channel count; then a 1x1 projection
        # (plus BN when enabled) makes the two branches addable.
        projection = []
        if not (stride == 1 and in_channels == out_channels):
            projection.append(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=conv_bias)
            )
            if use_bn:
                projection.append(nn.BatchNorm2d(out_channels))
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.bn2(self.conv2(main))
        skip = self.shortcut(x)
        return self.relu(main + skip)
class SmallResNet(nn.Module):
    """Compact ResNet for CIFAR-10: a conv stem, 3 stages, 6 residual blocks.

    The stem maps 3 input channels to 16. Three stages of two ``ResBlock``s
    each follow, with widths 16, 32 and 64; the second and third stages use
    stride 2 in their first block. Global average pooling and a linear layer
    produce ``num_classes`` outputs.

    Args:
        use_bn: Whether batch normalization is used throughout the network.
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, use_bn=True, num_classes=10):
        super().__init__()
        self.use_bn = use_bn

        stem_modules = [
            nn.Conv2d(3, 16, 3, stride=1, padding=1, bias=not use_bn),
            nn.BatchNorm2d(16) if use_bn else nn.Identity(),
            nn.ReLU(inplace=True),
        ]
        self.stem = nn.Sequential(*stem_modules)

        # (input channels, output channels, stride of first block) per stage;
        # registered as layer1, layer2, layer3 in that order.
        stage_specs = ((16, 16, 1), (16, 32, 2), (32, 64, 2))
        for index, (c_in, c_out, first_stride) in enumerate(stage_specs, start=1):
            stage = self._make_layer(c_in, c_out, 2, stride=first_stride, use_bn=use_bn)
            setattr(self, f'layer{index}', stage)

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    @staticmethod
    def _make_layer(in_ch, out_ch, n_blocks, stride, use_bn):
        """Build one stage: a (possibly strided) block plus stride-1 blocks.

        The first block maps ``in_ch`` to ``out_ch`` with the given stride;
        the remaining ``n_blocks - 1`` blocks keep ``out_ch`` channels.
        """
        blocks = [ResBlock(in_ch, out_ch, stride=stride, use_bn=use_bn)]
        blocks.extend(
            ResBlock(out_ch, out_ch, stride=1, use_bn=use_bn)
            for _ in range(n_blocks - 1)
        )
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Run stem, the three stages, pooling and the classifier on ``x``."""
        features = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        # Pool to 1x1 spatially, then flatten everything after the batch dim.
        pooled = self.pool(features)
        return self.fc(pooled.flatten(1))
# Build both variants (with and without batch norm) on the target device,
# then report the parameter count of the batch-norm variant.
model_bn = SmallResNet(use_bn=True).to(DEVICE)
model_no_bn = SmallResNet(use_bn=False).to(DEVICE)
n_params = sum(
    tensor.numel()
    for tensor in model_bn.parameters()
)
print(f'SmallResNet parameters: {n_params:,}')