模型:

cross-encoder/ms-marco-MiniLM-L-4-v2

英文

MS Marco交叉编码器

该模型是在 MS Marco Passage Ranking 任务上进行训练的。

该模型可用于信息检索:给定一个查询,将查询与所有可能的段落(例如使用ElasticSearch检索到的段落)进行编码。然后按降序对段落进行排序。更多详细信息请参见 SBERT.net Retrieve & Re-rank 。训练代码在此处可用: SBERT.net Training MS Marco

使用Transformers

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model = AutoModelForSequenceClassification.from_pretrained('model_name')
tokenizer = AutoTokenizer.from_pretrained('model_name')

features = tokenizer(['How many people live in Berlin?', 'How many people live in Berlin?'], ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.'],  padding=True, truncation=True, return_tensors="pt")

model.eval()
with torch.no_grad():
    scores = model(**features).logits
    print(scores)

使用SentenceTransformers

如果您已经安装了 SentenceTransformers ,则使用预训练模型会更加简单。然后,您可以像这样使用预训练模型:

from sentence_transformers import CrossEncoder
model = CrossEncoder('model_name', max_length=512)
scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')])

性能

在下表中,我们提供了各种预训练的交叉编码器以及它们在 TREC Deep Learning 2019 MS Marco Passage Reranking 数据集上的性能。

Model-Name NDCG@10 (TREC DL 19) MRR@10 (MS Marco Dev) Docs / Sec
Version 2 models
cross-encoder/ms-marco-TinyBERT-L-2-v2 69.84 32.56 9000
cross-encoder/ms-marco-MiniLM-L-2-v2 71.01 34.85 4100
cross-encoder/ms-marco-MiniLM-L-4-v2 73.04 37.70 2500
cross-encoder/ms-marco-MiniLM-L-6-v2 74.30 39.01 1800
cross-encoder/ms-marco-MiniLM-L-12-v2 74.31 39.02 960
Version 1 models
cross-encoder/ms-marco-TinyBERT-L-2 67.43 30.15 9000
cross-encoder/ms-marco-TinyBERT-L-4 68.09 34.50 2900
cross-encoder/ms-marco-TinyBERT-L-6 69.57 36.13 680
cross-encoder/ms-marco-electra-base 71.99 36.41 340
Other models
nboost/pt-tinybert-msmarco 63.63 28.80 2900
nboost/pt-bert-base-uncased-msmarco 70.94 34.75 340
nboost/pt-bert-large-msmarco 73.36 36.48 100
Capreolus/electra-base-msmarco 71.23 36.89 340
amberoad/bert-multilingual-passage-reranking-msmarco 68.40 35.54 330
sebastian-hofstaetter/distilbert-cat-margin_mse-T2-msmarco 72.82 37.88 720

注意:运行时间是在V100 GPU上计算的。