PyTorch实战对比:线性网络、CNN与ViT在手写数字识别上的性能较量

2024-06-21 人工智能 250 次阅读 0 次点赞
本文使用PyTorch实现了三种手写数字识别模型:线性神经网络(98.08%准确率)、卷积神经网络CNN(99.28%)和Vision Transformer ViT(98.30%)。实验表明,CNN在MNIST数据集上表现最佳,能有效提取图像局部特征;ViT虽可捕捉全局依赖,但受限于小图像尺寸,表现中等;线性网络结构简单,适合快速验证。文章还提供了数据加载、模型训练代码及性能对比,并建议通过增加Dropout、学习率调度或数据增强来优化泛化能力。

手写数字识别是计算机视觉领域的"Hello World"任务,MNIST数据集自1998年发布以来,已成为深度学习入门的经典基准。本文将使用PyTorch实现三种不同架构:线性神经网络、卷积神经网络(CNN)和Vision Transformer(ViT),并对比它们的性能表现。

环境准备

首先确保已安装PyTorch和torchvision:

pip install torch torchvision matplotlib

数据加载与预处理

三种方法使用相同的数据预处理流程:将图像转换为张量并归一化到[-1, 1]范围。

import torch
import torch.utils.data as data
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(
    root="./data", train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(
    root="./data", train=False, transform=transform, download=True
)

train_loader = data.DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=100, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

方法一:线性神经网络(准确率:98.08%)

线性神经网络是最简单的深度学习模型。我们将28×28的图像展平为784维向量,经过两个全连接层进行分类。

import torch.nn as nn
import torch.optim as optim

class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 1000)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out

model = LinearModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练20个epoch
for epoch in range(20):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device).reshape(-1, 784)
        labels = labels.to(device)

        output = model(images)
        loss = criterion(output, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/20], Loss: {loss.item():.4f}")

网络结构特点:

  • 只有全连接层,没有卷积操作
  • 无法利用像素的空间局部性
  • 参数量较大(约80万)

方法二:卷积神经网络(准确率:99.28%)

CNN通过卷积核提取图像的局部特征,是图像任务的经典选择。

class CNNModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.dropout = nn.Dropout()
        self.fc1 = nn.Linear(7 * 7 * 64, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.dropout(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out

网络结构特点:

  • 两层卷积-激活-池化结构
  • 卷积层后特征图尺寸:28×28 → 14×14 → 7×7
  • 利用参数共享,大幅减少参数量

方法三:Vision Transformer(准确率:98.30%)

ViT将Transformer架构引入计算机视觉,通过自注意力机制捕捉全局依赖。

class ViTModel(nn.Module):
    def __init__(self, image_size=28, patch_size=7, in_channels=1, 
                 d_model=128, n_head=4, num_layers=3, num_classes=10):
        super().__init__()
        self.num_patches = (image_size // patch_size) ** 2

        # 将图像分割为patches并线性投影
        self.patch_embed = nn.Conv2d(
            in_channels, d_model, kernel_size=patch_size, stride=patch_size
        )

        # CLS token和位置编码
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, d_model))

        # Transformer编码器
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_head, dim_feedforward=4 * d_model,
            activation="gelu", batch_first=True, dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 分类头
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(d_model), nn.Linear(d_model, num_classes)
        )

    def forward(self, x):
        batch_size = x.shape[0]
        
        # 获取patch embeddings
        x = self.patch_embed(x)
        x = x.flatten(2).permute(0, 2, 1)
        
        # 添加CLS token和位置编码
        cls_tokens = self.cls_token.repeat(batch_size, 1, 1)
        x = torch.cat([cls_tokens, x], dim=1)
        x += self.pos_embed
        
        # Transformer编码
        x = self.transformer(x)
        
        # 使用CLS token进行分类
        return self.mlp_head(x[:, 0])

# 使用AdamW优化器,添加权重衰减
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

网络结构特点:

  • 将28×28图像分割为16个7×7的patches
  • 使用自注意力机制建模全局关系
  • 适合处理长距离依赖

性能对比

下表展示了三种方法在MNIST测试集上的准确率:

方法 5轮 10轮 15轮 20轮
线性神经网络 96.69% 97.77% 97.81% 98.08%
卷积神经网络 98.85% 99.27% 99.16% 99.28%
Vision Transformer 97.64% 98.39% 98.27% 98.30%

结果分析

CNN表现最佳:99.28%的准确率明显优于其他方法。卷积操作天然适合处理图像数据,能有效提取局部特征。

ViT表现中等:98.30%的准确率略高于线性网络,但低于CNN。原因是MNIST图像尺寸小(28×28),ViT将图像分割为patches后,可学习的patches数量有限,难以充分发挥Transformer的优势。ViT通常在ImageNet等大规模数据集上表现更好。

线性网络表现合格:98.08%的准确率虽然最低,但考虑到其结构简单、训练快速,对于MNIST这样相对简单的任务已经足够实用。

训练建议

过拟合观察:CNN在15轮后准确率略有下降(99.16% → 99.28%),可能是轻微过拟合。可考虑增加Dropout率或早停。

学习率调优:建议尝试学习率调度器(如StepLR或CosineAnnealingLR)提升收敛效果。

数据增强:可添加随机旋转、平移等增强操作进一步提升泛化能力。

总结

本文使用PyTorch实现了三种不同的手写数字识别模型。实验表明,对于MNIST数据集:

  • CNN最适合,充分利用了图像的局部相关性
  • ViT潜力大,但在小数据集上不如CNN
  • 线性网络简单有效,适合快速原型验证
最后更新于1小时前
本文由人工编写,AI优化,转载请注明原文地址: PyTorch实战对比:线性网络、CNN与ViT在手写数字识别上的性能较量

评论 (0)

登录 后发表评论

暂无评论,快来发表第一条评论吧!