Model:
facebook/contriever-msmarco
This model is a fine-tuned version of the pre-trained Contriever model (available at https://huggingface.co/facebook/contriever), trained following the approach described in "Towards Unsupervised Dense Information Retrieval with Contrastive Learning". The associated GitHub repository is available at https://github.com/facebookresearch/contriever.
To use the model directly with HuggingFace Transformers, you need to add a mean pooling operation to obtain sentence embeddings.
import torch
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
model = AutoModel.from_pretrained('facebook/contriever-msmarco')
sentences = [
    "Where was Marie Curie born?",
    "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
    "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
]
# Apply tokenizer
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
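# Compute token embeddings (gradients are not needed for inference)
with torch.no_grad():
    outputs = model(**inputs)
# Mean pooling: average the token embeddings, ignoring padded positions
def mean_pooling(token_embeddings, mask):
    # Zero out the embeddings at padded positions (mask == 0)
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
    # Sum over the sequence dimension and divide by the number of real tokens
    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
    return sentence_embeddings
# outputs[0] is the last hidden state, with shape (batch, seq_len, hidden_size)
embeddings = mean_pooling(outputs[0], inputs['attention_mask'])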
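Once the sentence embeddings are computed, the question can be compared against the passages. The sketch below assumes dot-product scoring (the relevance score used in the Contriever paper) and uses small hand-made tensors in place of real model outputs, so it runs without downloading the model. The helper `rank_by_dot_product` is a hypothetical name introduced for illustration; it also shows that padded positions do not affect the pooled vector.

```python
import torch


def mean_pooling(token_embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Same masked average as above: padded positions contribute nothing
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0)
    return token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]


def rank_by_dot_product(query_emb: torch.Tensor, passage_embs: torch.Tensor):
    # Hypothetical helper: score each passage by its dot product with the query,
    # then return the scores and the passage indices sorted from best to worst
    scores = passage_embs @ query_emb
    order = torch.argsort(scores, descending=True)
    return scores, order


# Toy stand-in for model output: batch of 3, seq_len 3, hidden size 2
token_embeddings = torch.tensor([
    [[1.0, 0.0], [3.0, 0.0], [9.0, 9.0]],  # last position is padding
    [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]],  # no padding
    [[0.0, 1.0], [0.0, 3.0], [5.0, 5.0]],  # last position is padding
])
mask = torch.tensor([[1, 1, 0], [1, 1, 1], [1, 1, 0]])

emb = mean_pooling(token_embeddings, mask)
# Treat row 0 as the query and rows 1-2 as candidate passages
scores, order = rank_by_dot_product(emb[0], emb[1:])
print(emb.tolist())     # [[2.0, 0.0], [2.0, 2.0], [0.0, 2.0]]
print(scores.tolist())  # [4.0, 0.0]
print(order.tolist())   # [0, 1]
```

With the real model, the same call would be `rank_by_dot_product(embeddings[0], embeddings[1:])`, ranking the two Curie sentences against the question. The padded rows (`[9.0, 9.0]` and `[5.0, 5.0]`) are excluded from the averages, which is exactly why the attention mask is passed to the pooling step.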