模型:
cmarkea/distilcamembert-base
我们提供了一种名为 CamemBERT 的DistilCamemBERT的蒸馏版本,它是一种RoBERTa的法语模型版本。蒸馏的目的是大幅减少模型的复杂性,同时保持性能。我们在 DistilBERT paper 中展示了这个概念验证,并且训练所使用的代码是受到 DistilBERT 代码的启发的。
对于蒸馏模型(学生模型)的训练旨在尽可能接近原始模型(教师模型)。为了实现这一点,损失函数由3部分组成:
最终的损失函数是这三个损失函数的组合。我们使用以下权重:
L o s s = 0.5 × D i s t i l L o s s + 0.3 × C o s i n e L o s s + 0.2 × M L M L o s s
为了减少学生和教师模型之间的偏差,用于DistilCamemBERT训练的数据集与camembert-base训练使用的数据集相同:OSCAR。该数据集的法语部分在硬盘上约占据140 GB的空间。
我们在一台nVidia Titan RTX上进行了18天的预训练。
Dataset name | f1-score |
---|---|
1236321 CLS | 83% |
1236321 PAWS-X | 77% |
1236321 XNLI | 77% |
1239321 NER | 98% |
加载DistilCamemBERT及其子词分词器:
from transformers import AutoTokenizer, AutoModel tokenizer = AutoTokenizer.from_pretrained("cmarkea/distilcamembert-base") model = AutoModel.from_pretrained("cmarkea/distilcamembert-base") model.eval() ...
使用管道填充掩码:
from transformers import pipeline model_fill_mask = pipeline("fill-mask", model="cmarkea/distilcamembert-base", tokenizer="cmarkea/distilcamembert-base") results = model_fill_mask("Le camembert est <mask> :)") results [{'sequence': '<s> Le camembert est délicieux :)</s>', 'score': 0.3878222405910492, 'token': 7200}, {'sequence': '<s> Le camembert est excellent :)</s>', 'score': 0.06469205021858215, 'token': 2183}, {'sequence': '<s> Le camembert est parfait :)</s>', 'score': 0.04534877464175224, 'token': 1654}, {'sequence': '<s> Le camembert est succulent :)</s>', 'score': 0.04128391295671463, 'token': 26202}, {'sequence': '<s> Le camembert est magnifique :)</s>', 'score': 0.02425697259604931, 'token': 1509}]
@inproceedings{delestre:hal-03674695, TITLE = {{DistilCamemBERT : une distillation du mod{\`e}le fran{\c c}ais CamemBERT}}, AUTHOR = {Delestre, Cyrile and Amar, Abibatou}, URL = {https://hal.archives-ouvertes.fr/hal-03674695}, BOOKTITLE = {{CAp (Conf{\'e}rence sur l'Apprentissage automatique)}}, ADDRESS = {Vannes, France}, YEAR = {2022}, MONTH = Jul, KEYWORDS = {NLP ; Transformers ; CamemBERT ; Distillation}, PDF = {https://hal.archives-ouvertes.fr/hal-03674695/file/cap2022.pdf}, HAL_ID = {hal-03674695}, HAL_VERSION = {v1}, }