CNN 和 Vision Transformer:分析与比较
2023年06月30日 由 Alex 发表
990976
0
探究Vision Transformers和卷积神经网络(CNNs)在图像分类任务中的有效性。
图像分类是计算机视觉中的关键任务,广泛应用于工业、医学影像和农业等各个领域。卷积神经网络(CNNs)在这个领域取得了重大突破,并被广泛使用。然而,随着论文《Attention is all you need》的出现,行业开始向Transformer转变。Transformer在人工智能和数据科学领域取得了显著进展,例如ChatGPT在语言生成任务中的出色表现就是Transformer有效性的最新例证。类似地,《ViT》论文提供了Vision Transformer的概览。在本文中,我将尝试比较CNNs和ViTs(Vision Transformers)在Food-101数据集上进行图像分类任务时的性能。值得注意的是,选择使用CNNs还是ViTs取决于多个因素,包括工作类型、训练时间和计算能力,并且我们不能直接断言Transformers比CNNs更好。该分析旨在提供关于它们在这个特定任务中性能的见解。
数据集
由于计算能力有限,我将易于访问的 Food-101 数据集(包含大约 101,000 张图像)划分为 10 个类别。
我将数据集划分为以下10个类别:
['samosa','pizza','red_velvet_cake', 'tacos', 'miso_soup', 'onion_rings', 'ramen', 'nachos', 'omelette', 'ice_cream']
将图像进行转换和调整大小,调整为256x256,并进行归一化处理,均值为0,方差为1。在对数据集进行子集划分之后,将数据集分为训练集和验证集,其中训练集包含7500个图像,测试集包含2500个图像。
这些是数据集中的示例图像:
为了比较CNN和ViT的性能,我使用了预训练的DenseNet121架构作为CNN的模型,而使用了ViT-16作为Vision Transformers的模型。选择DenseNet121是基于其具有121层的密集结构,使其成为与ViT在训练时间、层数、硬件和内存需求方面进行比较的合适候选模型。对于ViT,我使用了ViT-Base模型,它包含12层和86M个参数。
DenseNet121
DenseNet-121是一个非常著名的用于图像分类的CNN架构,它是DenseNet模型系列的一部分,旨在解决非常深的神经网络中可能出现的梯度消失问题。它有121层,结合了卷积层、池化层和全连接层。模型由4个密集块组成,每个密集块包含多个具有BatchNorm和ReLU激活的卷积层。在密集块之间,使用pooing操作的转换层来降低特征图的空间维度。下面是DenseNet的架构图:
使用了PyTorch中预训练的模型。该模型经过了10个epochs的训练。
# Constants
NUM_CLASSES = 10
LEARNING_RATE = 0.001
# Model
densenet = torch.hub.load('pytorch/vision:v0.10.0', 'densenet121', pretrained=True)
for param in densenet.parameters():
param.requires_grad = False
# Change classifier layer
densenet.classifier = nn.Linear(1024,NUM_CLASSES)
# Loss, Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(densenet.classifier.parameters(), lr=LEARNING_RATE)
准确率和损失值随epochs的变化的图表:
在最后一个epoch中,训练损失率为0.3671,测试损失率为0.3586,训练准确率为88.29%,测试准确率为87.72%。
分类报告:
ViT-16
ViT-16是Vision Transformer(ViT)的一个变种,在ViT论文发布后因其在各种图像分类基准上取得最先进结果而受到关注。ViT-16由一个Transformer编码器和一个用于分类的多层感知机(MLP)组成。Transformer编码器由一系列16个相同的Transformer层组成,其中每个层包含自注意机制和前馈神经网络。网络的输入是一个扁平化的图像补丁序列,该序列是通过将输入图像分成不重叠的补丁并将每个补丁扁平化为一个向量来获得的。
每个Transformer层中的自注意机制使网络在进行预测时能够专注于图像的不同部分。具体而言,它计算输入序列中每对位置的注意力权重,使网络能够根据它们与当前分类任务的相关性而关注不同的补丁。然后,在每个Transformer层中,前馈神经网络对自注意机制的输出应用非线性变换。
在经过Transformer编码器后,输出通过MLP分类器。该分类器由两个具有ReLU激活函数的全连接层和一个用于分类的softmax输出层组成。MLP将最终Transformer层的输出作为输入,将其映射为输出类别的概率分布。
以下是ViT的架构:
在将图像传递给Transformer编码器模型之前,我们需要将输入图像划分为补丁,然后将补丁扁平化。以下是图像划分为补丁的示例:
我从头开始构建了Transformer模型,但是表现并不理想。然后,我尝试了迁移学习,并使用了预训练的ViT-16模型和PyTorch的默认权重。我还对适用于ViT的图像应用了变换操作。
# Default weights
pretrained_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
# Model
vit = vit_b_16(weights=pretrained_weights).to(device)
for parameter in vit.parameters():
parameter.requires_grad=False
# Change last layer
vit.heads = nn.Linear(in_features=768, out_features=10)
# Auto Transforms
vit_transforms = pretrained_weights.transforms()
准确率与迭代次数和损失值与迭代次数之间的图形:
在最后一个epoch中,训练损失率为0.1203,测试损失率为0.01893,训练准确率为96.89%,测试准确率为93.63%。
分类报告:
预测结果:
以下是使用未见过的数据对ViT-16模型进行的一些预测结果:
披萨
拉面
萨莫萨三角饺
在大多数情况下,ViT-16能够正确分类未见过的数据。
结论:
在这个特定的任务中,从图像分类的角度来看,ViT-16的性能优于DenseNet121。准确率和曲线图也显示了两者之间的显著差异。分类报告显示,与DenseNet相比,ViT的f1更好。
然而,需要注意的是,虽然Vision Transformers在某些情况下可能优于CNN,但不能一概而论地认为它们比CNN架构更好。每种架构的性能取决于多种因素,例如用例、数据规模、训练时间、参数调优、硬件的内存和计算能力等。
参考文献:
1. 《Attention is all you need》论文 https://arxiv.org/abs/1706.03762
2. 《DenseNet》论文 - https://arxiv.org/pdf/1608.06993.pdf
3. 《Vision Transformers》论文 - https://arxiv.org/pdf/2010.11929.pdf
来源:https://medium.com/@vikrampande783/cnns-and-vision-transformers-analysis-and-comparison-bf7b109718ba