模型:

google/fnet-base

英文

FNet基础模型

使用掩码语言建模(MLM)和下一个句子预测(NSP)目标在英语语言上预训练的模型。它于 this paper 年发布,并于 this repository 首次发布。此模型区分大小写:它区分英语和 English。该模型在MLM目标上实现了0.58的准确率,在NSP目标上实现了0.80的准确率。

免责声明:此模型卡片由 gchhablani 编写。

模型描述

FNet是一个基于Fourier变换替代注意力的Transformer模型。因此,输入不包含attention_mask。它是在自监督方式下对大量英语数据进行预训练。这意味着它仅使用原始文本进行预训练,没有任何人为标记的方式(这就是为什么它可以使用大量的公开可用数据),使用自动过程从这些文本中生成输入和标签。更准确地说,它通过两个目标进行预训练:

  • 掩码语言建模(MLM):对于一个句子,模型会随机掩码输入中的15%的单词,然后将整个掩码句子输入模型,并尝试预测掩码的单词。这与传统的循环神经网络(RNN)不同,传统RNN通常按照单词的顺序接收输入,或者像GPT这样的自回归模型会内部掩码未来的标记。使用这种方式,模型可以学习到句子的双向表示。
  • 下一个句子预测(NSP):在预训练期间,模型将两个被掩码的句子连接在一起作为输入。这些句子有时与原始文本中相邻的句子对应,有时不对应。模型需要预测这两个句子是否相邻。

通过这种方式,模型学习了英语语言的内部表示,可以用于提取对下游任务有用的特征:例如,如果你有一个带有标记句子的数据集,你可以使用FNet模型生成的特征作为输入来训练标准分类器。

预期用途和限制

您可以直接使用原始模型进行掩码语言建模或下一个句子预测,但它主要用于在下游任务上进行微调。您可以查看感兴趣的任务的fine-tuned版本。

请注意,该模型主要用于对整个句子(可能被掩码)进行决策的任务,例如序列分类、标记分类或问题回答。对于文本生成等任务,您应该看一下像GPT2这样的模型。

训练数据

FNet模型是在 C4 上预训练的,这是Common Crawl数据集的清理版本。

训练过程

预处理

文本经过小写处理并使用SentencePiece和词汇表大小为32,000进行分词。模型的输入形式如下:

[CLS] Sentence A [SEP] Sentence B [SEP]

以0.5的概率,句子A和句子B对应于原始语料库中的两个连续的句子,而在其他情况下,它们是语料库中的另一个随机句子。请注意,这里所指的句子是连续的文本段通常比单个句子长。唯一的约束是两个"句子"的结果的长度之和小于512个标记。

每个句子的掩码过程的详细情况如下:

  • 占15%的标记被掩码。
  • 在80%的情况下,被掩码的标记被替换为 [MASK]。
  • 在剩下的10%的情况下,被掩码的标记被替换为一个与其不同的随机标记。
  • 在剩下的10%的情况下,被掩码的标记保持不变。

预训练

FNet-base使用16个TPU芯片的Pod配置进行100万步训练,批量大小为256。序列长度被限制为512个标记。采用Adam优化器,学习率为1e-4, β 1 = 0.9, β 2 = 0.999,权重衰减为0.01,学习率预热步数为10,000步,之后的学习率线性衰减。

评估结果

FNet-base在 GLUE benchamrk 的验证数据集上进行了微调和评估。官方模型(使用Flax编写)的结果可在 the official paper 的第7页的表1中看到。

为了比较,该模型(转换为PyTorch版本)使用了 official Hugging Face GLUE evaluation scripts 并与 bert-base-cased 进行了比较。在单个16GB的NVIDIA Tesla V100 GPU上进行训练。对于MRPC/WNLI,模型训练了5个epochs,而其他任务训练了3个epochs。使用512的序列长度,批量大小为16,学习率为2e-5。

以下表格总结了 fnet-base (称为 FNet(PyTorch)- Reproduced )和 bert-base-cased (称为 Bert(PyTorch)- Reproduced )在微调速度方面的结果,并将其与官方FNet-base模型(称为 FNet(Flax)- Official )的表现进行了比较。请注意,重现模型的训练超参数与官方模型不同,因此某些任务的性能可能存在明显差异(例如:CoLA)。

Task/Model FNet-base (PyTorch) Bert-base (PyTorch)
MNLI-(m/mm) 12316321 12317321
QQP 12318321 12319321
QNLI 12320321 12321321
SST-2 12322321 12323321
CoLA 12324321 12325321
STS-B 12326321 12327321
MRPC 12328321 12329321
RTE 12330321 12331321
WNLI 12332321 12333321
SUM 16:30:45 24:23:56

可以看到,FNet-base的平均性能约为BERT-base的93%。

有关更多详细信息,请参阅与分数相关联的检查点。可以访问以下表格的所有微调检查点概述 here

如何使用

您可以使用此模型直接进行掩码语言建模的流水线处理:

注意:掩码填充流程与原始模型的掩码操作不完全相同。在掩码流水线中,在 [MASK] 后额外添加了一个空格。

>>> from transformers import FNetForMaskedLM, FNetTokenizer, pipeline
>>> tokenizer = FNetTokenizer.from_pretrained("google/fnet-base")
>>> model = FNetForMaskedLM.from_pretrained("google/fnet-base")
>>> unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer)
>>> unmasker("Hello I'm a [MASK] model.")

[
    {"sequence": "hello i'm a new model.", "score": 0.12073223292827606, "token": 351, "token_str": "new"},
    {"sequence": "hello i'm a first model.", "score": 0.08501081168651581, "token": 478, "token_str": "first"},
    {"sequence": "hello i'm a next model.", "score": 0.060546260327100754, "token": 1037, "token_str": "next"},
    {"sequence": "hello i'm a last model.", "score": 0.038265593349933624, "token": 813, "token_str": "last"},
    {"sequence": "hello i'm a sister model.", "score": 0.033868927508592606, "token": 6232, "token_str": "sister"},
]

以下是在PyTorch中使用此模型获取给定文本特征的方法:

注意:您必须将最大序列长度指定为512,并对齐/填充为相同的长度,因为原始模型没有attention_mask,并在前向传递过程中考虑所有隐藏状态。

from transformers import FNetTokenizer, FNetModel
tokenizer = FNetTokenizer.from_pretrained("google/fnet-base")
model = FNetModel.from_pretrained("google/fnet-base")
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
output = model(**encoded_input)

BibTeX条目和引文信息

@article{DBLP:journals/corr/abs-2105-03824,
  author    = {James Lee{-}Thorp and
               Joshua Ainslie and
               Ilya Eckstein and
               Santiago Onta{\~{n}}{\'{o}}n},
  title     = {FNet: Mixing Tokens with Fourier Transforms},
  journal   = {CoRR},
  volume    = {abs/2105.03824},
  year      = {2021},
  url       = {https://arxiv.org/abs/2105.03824},
  archivePrefix = {arXiv},
  eprint    = {2105.03824},
  timestamp = {Fri, 14 May 2021 12:13:30 +0200},
  biburl    = {https://dblp.org/rec/journals/corr/abs-2105-03824.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}

贡献

感谢 @gchhablani 添加了此模型。