这是在 MultiNLI (MNLI) 数据集上训练后的检查点 bart-large 。
关于这个模型的额外信息:
Yin et al. 提出了一种使用预训练的NLI模型作为现成的零样本序列分类器的方法。该方法通过将待分类的序列作为NLI前提,并为每个候选标签构建一个假设。例如,如果我们想评估一个序列是否属于"政治"类,我们可以构建一个假设:这段文本是关于政治的。然后将蕴涵和矛盾的概率转换为标签概率。
这种方法在许多情况下效果出人意料地好,尤其是在使用像BART和Roberta这样的较大预训练模型时。有关此方法和其他零样本方法的更详尽介绍,请参阅 this blog post ;有关使用此模型进行零样本分类的示例的代码片段,请参阅下面使用Hugging Face内置流水线和本机Transformers/PyTorch代码的示例。
With the zero-shot classification pipeline可以使用zero-shot-classification流水线加载模型,如下所示:
from transformers import pipeline classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
然后,您可以使用此流水线将序列分类为您指定的任何类名。
sequence_to_classify = "one day I will see the world" candidate_labels = ['travel', 'cooking', 'dancing'] classifier(sequence_to_classify, candidate_labels) #{'labels': ['travel', 'dancing', 'cooking'], # 'scores': [0.9938651323318481, 0.0032737774308770895, 0.002861034357920289], # 'sequence': 'one day I will see the world'}
如果有多个候选标签可能是正确的,请传递multi_class=True以独立计算每个类别:
candidate_labels = ['travel', 'cooking', 'dancing', 'exploration'] classifier(sequence_to_classify, candidate_labels, multi_class=True) #{'labels': ['travel', 'exploration', 'dancing', 'cooking'], # 'scores': [0.9945111274719238, # 0.9383890628814697, # 0.0057061901316046715, # 0.0018193122232332826], # 'sequence': 'one day I will see the world'}With manual PyTorch
# pose sequence as a NLI premise and label as a hypothesis from transformers import AutoModelForSequenceClassification, AutoTokenizer nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli') tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli') premise = sequence hypothesis = f'This example is {label}.' # run through model pre-trained on MNLI x = tokenizer.encode(premise, hypothesis, return_tensors='pt', truncation_strategy='only_first') logits = nli_model(x.to(device))[0] # we throw away "neutral" (dim 1) and take the probability of # "entailment" (2) as the probability of the label being true entail_contradiction_logits = logits[:,[0,2]] probs = entail_contradiction_logits.softmax(dim=1) prob_label_is_true = probs[:,1]