模型:

cmarkea/distilcamembert-base

英文

DistilCamemBERT

我们提供了一种名为 CamemBERT 的DistilCamemBERT的蒸馏版本,它是一种RoBERTa的法语模型版本。蒸馏的目的是大幅减少模型的复杂性,同时保持性能。我们在 DistilBERT paper 中展示了这个概念验证,并且训练所使用的代码是受到 DistilBERT 代码的启发的。

损失函数

对于蒸馏模型(学生模型)的训练旨在尽可能接近原始模型(教师模型)。为了实现这一点,损失函数由3部分组成:

  • DistilLoss:一种蒸馏损失,用于衡量学生和教师模型输出概率之间的相似度,并在MLM任务上使用交叉熵损失。
  • CosineLoss:余弦嵌入损失。这个损失函数应用于学生和教师模型的最后隐藏层,以保证它们之间的共线性。
  • MLMLoss:最后是一项遮蔽语言模型(MLM)任务损失,以使用教师模型的原始任务执行学生模型。

最终的损失函数是这三个损失函数的组合。我们使用以下权重:

L o s s = 0.5 × D i s t i l L o s s + 0.3 × C o s i n e L o s s + 0.2 × M L M L o s s

数据集

为了减少学生和教师模型之间的偏差,用于DistilCamemBERT训练的数据集与camembert-base训练使用的数据集相同:OSCAR。该数据集的法语部分在硬盘上约占据140 GB的空间。

训练

我们在一台nVidia Titan RTX上进行了18天的预训练。

评估结果

Dataset name f1-score
1236321 CLS 83%
1236321 PAWS-X 77%
1236321 XNLI 77%
1239321 NER 98%

如何使用DistilCamemBERT

加载DistilCamemBERT及其子词分词器:

from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("cmarkea/distilcamembert-base")
model = AutoModel.from_pretrained("cmarkea/distilcamembert-base")
model.eval()
...

使用管道填充掩码:

from transformers import pipeline

model_fill_mask = pipeline("fill-mask", model="cmarkea/distilcamembert-base", tokenizer="cmarkea/distilcamembert-base")
results = model_fill_mask("Le camembert est <mask> :)")

results
[{'sequence': '<s> Le camembert est délicieux :)</s>', 'score': 0.3878222405910492, 'token': 7200},
 {'sequence': '<s> Le camembert est excellent :)</s>', 'score': 0.06469205021858215, 'token': 2183}, 
 {'sequence': '<s> Le camembert est parfait :)</s>', 'score': 0.04534877464175224, 'token': 1654}, 
 {'sequence': '<s> Le camembert est succulent :)</s>', 'score': 0.04128391295671463, 'token': 26202}, 
 {'sequence': '<s> Le camembert est magnifique :)</s>', 'score': 0.02425697259604931, 'token': 1509}]

引用

@inproceedings{delestre:hal-03674695,
  TITLE = {{DistilCamemBERT : une distillation du mod{\`e}le fran{\c c}ais CamemBERT}},
  AUTHOR = {Delestre, Cyrile and Amar, Abibatou},
  URL = {https://hal.archives-ouvertes.fr/hal-03674695},
  BOOKTITLE = {{CAp (Conf{\'e}rence sur l'Apprentissage automatique)}},
  ADDRESS = {Vannes, France},
  YEAR = {2022},
  MONTH = Jul,
  KEYWORDS = {NLP ; Transformers ; CamemBERT ; Distillation},
  PDF = {https://hal.archives-ouvertes.fr/hal-03674695/file/cap2022.pdf},
  HAL_ID = {hal-03674695},
  HAL_VERSION = {v1},
}