PyTorch实战对比:线性网络、CNN与ViT在手写数字识别上的性能较量
手写数字识别是计算机视觉领域的"Hello World"任务,MNIST数据集自1998年发布以来,已成为深度学习入门的经典基准。本文将使用PyTorch实现三种不同架构:线性神经网络、卷积神经网络(CNN)和Vision Transformer(ViT),并对比它们的性能表现。
环境准备
首先确保已安装PyTorch和torchvision:
pip install torch torchvision matplotlib
数据加载与预处理
三种方法使用相同的数据预处理流程:将图像转换为张量并归一化到[-1, 1]范围。
import torch
import torch.utils.data as data
from torchvision import datasets, transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(
root="./data", train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(
root="./data", train=False, transform=transform, download=True
)
train_loader = data.DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=100, shuffle=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
方法一:线性神经网络(准确率:98.08%)
线性神经网络是最简单的深度学习模型。我们将28×28的图像展平为784维向量,经过两个全连接层进行分类。
import torch.nn as nn
import torch.optim as optim
class LinearModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 1000)
self.relu = nn.ReLU()
self.dropout = nn.Dropout()
self.fc2 = nn.Linear(1000, 10)
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.dropout(out)
out = self.fc2(out)
return out
model = LinearModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练20个epoch
for epoch in range(20):
for i, (images, labels) in enumerate(train_loader):
images = images.to(device).reshape(-1, 784)
labels = labels.to(device)
output = model(images)
loss = criterion(output, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i + 1) % 100 == 0:
print(f"Epoch [{epoch+1}/20], Loss: {loss.item():.4f}")
网络结构特点:
- 只有全连接层,没有卷积操作
- 无法利用像素的空间局部性
- 参数量较大(约80万)
方法二:卷积神经网络(准确率:99.28%)
CNN通过卷积核提取图像的局部特征,是图像任务的经典选择。
class CNNModel(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=5, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.layer2 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=5, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.dropout = nn.Dropout()
self.fc1 = nn.Linear(7 * 7 * 64, 1000)
self.fc2 = nn.Linear(1000, 10)
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.reshape(out.size(0), -1)
out = self.dropout(out)
out = self.fc1(out)
out = self.fc2(out)
return out
网络结构特点:
- 两层卷积-激活-池化结构
- 卷积层后特征图尺寸:28×28 → 14×14 → 7×7
- 利用参数共享,大幅减少参数量
方法三:Vision Transformer(准确率:98.30%)
ViT将Transformer架构引入计算机视觉,通过自注意力机制捕捉全局依赖。
class ViTModel(nn.Module):
def __init__(self, image_size=28, patch_size=7, in_channels=1,
d_model=128, n_head=4, num_layers=3, num_classes=10):
super().__init__()
self.num_patches = (image_size // patch_size) ** 2
# 将图像分割为patches并线性投影
self.patch_embed = nn.Conv2d(
in_channels, d_model, kernel_size=patch_size, stride=patch_size
)
# CLS token和位置编码
self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, d_model))
# Transformer编码器
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=n_head, dim_feedforward=4 * d_model,
activation="gelu", batch_first=True, dropout=0.1
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
# 分类头
self.mlp_head = nn.Sequential(
nn.LayerNorm(d_model), nn.Linear(d_model, num_classes)
)
def forward(self, x):
batch_size = x.shape[0]
# 获取patch embeddings
x = self.patch_embed(x)
x = x.flatten(2).permute(0, 2, 1)
# 添加CLS token和位置编码
cls_tokens = self.cls_token.repeat(batch_size, 1, 1)
x = torch.cat([cls_tokens, x], dim=1)
x += self.pos_embed
# Transformer编码
x = self.transformer(x)
# 使用CLS token进行分类
return self.mlp_head(x[:, 0])
# 使用AdamW优化器,添加权重衰减
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
网络结构特点:
- 将28×28图像分割为16个7×7的patches
- 使用自注意力机制建模全局关系
- 适合处理长距离依赖
性能对比
下表展示了三种方法在MNIST测试集上的准确率:
| 方法 | 5轮 | 10轮 | 15轮 | 20轮 |
|---|---|---|---|---|
| 线性神经网络 | 96.69% | 97.77% | 97.81% | 98.08% |
| 卷积神经网络 | 98.85% | 99.27% | 99.16% | 99.28% |
| Vision Transformer | 97.64% | 98.39% | 98.27% | 98.30% |
结果分析
CNN表现最佳:99.28%的准确率明显优于其他方法。卷积操作天然适合处理图像数据,能有效提取局部特征。
ViT表现中等:98.30%的准确率略高于线性网络,但低于CNN。原因是MNIST图像尺寸小(28×28),ViT将图像分割为patches后,可学习的patches数量有限,难以充分发挥Transformer的优势。ViT通常在ImageNet等大规模数据集上表现更好。
线性网络表现合格:98.08%的准确率虽然最低,但考虑到其结构简单、训练快速,对于MNIST这样相对简单的任务已经足够实用。
训练建议
过拟合观察:CNN在15轮后准确率略有下降(99.16% → 99.28%),可能是轻微过拟合。可考虑增加Dropout率或早停。
学习率调优:建议尝试学习率调度器(如StepLR或CosineAnnealingLR)提升收敛效果。
数据增强:可添加随机旋转、平移等增强操作进一步提升泛化能力。
总结
本文使用PyTorch实现了三种不同的手写数字识别模型。实验表明,对于MNIST数据集:
- CNN最适合,充分利用了图像的局部相关性
- ViT潜力大,但在小数据集上不如CNN
- 线性网络简单有效,适合快速原型验证