Model:
facebook/contriever-msmarco
This model is a fine-tuned version of the pre-trained Contriever model, available at https://huggingface.co/facebook/contriever, following the approach described in Towards Unsupervised Dense Information Retrieval with Contrastive Learning. The associated GitHub repository is available at https://github.com/facebookresearch/contriever.
Using the model directly with HuggingFace Transformers requires adding a mean pooling operation to obtain sentence embeddings.
```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
model = AutoModel.from_pretrained('facebook/contriever-msmarco')

sentences = [
    "Where was Marie Curie born?",
    "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
    "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
]

# Apply tokenizer (pad to the longest sentence, truncate to the model's max length)
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

# Compute token embeddings; gradients are not needed for inference
with torch.no_grad():
    outputs = model(**inputs)

# Mean pooling: average token embeddings over real tokens, ignoring padding
def mean_pooling(token_embeddings, mask):
    # Zero out embeddings at padding positions (attention_mask == 0)
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
    # Divide by the number of non-padding tokens in each sentence
    sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
    return sentence_embeddings

# outputs[0] is the last hidden state: (batch, seq_len, hidden_size)
embeddings = mean_pooling(outputs[0], inputs['attention_mask'])
```
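The pooling step, and the retrieval scoring that usually follows it, can be checked without downloading the model. The sketch below applies the same `mean_pooling` logic to a small hand-made tensor that contains one padding position, then ranks two toy "passages" against a toy query. The toy tensors and the `score` helper are hypothetical, and dot-product scoring is an assumption made for illustration; this card does not specify the scoring function.

```python
import torch


def mean_pooling(token_embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Zero out embeddings at padding positions (mask == 0)
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0)
    # Average over the real (non-padding) tokens of each sequence
    return token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]


def score(query_emb: torch.Tensor, passage_embs: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. Assumption: relevance is the dot product
    # between the query embedding and each passage embedding.
    return passage_embs @ query_emb


# Toy batch: 2 sequences, max length 3, hidden size 2.
# The last position of the second sequence is padding; its large values
# would distort the average if the mask were ignored.
token_embeddings = torch.tensor([
    [[1.0, 0.0], [3.0, 0.0], [5.0, 0.0]],
    [[0.0, 2.0], [0.0, 4.0], [100.0, 100.0]],
])
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

embeddings = mean_pooling(token_embeddings, attention_mask)
print(embeddings.tolist())  # [[3.0, 0.0], [0.0, 3.0]]

# Rank the two pooled "passages" against a toy query embedding
query_emb = torch.tensor([1.0, 0.0])
scores = score(query_emb, embeddings)
ranking = torch.argsort(scores, descending=True)
print(scores.tolist(), ranking.tolist())  # [3.0, 0.0] [0, 1]
```

With the real model, the same helper could be applied to the `embeddings` from the example above as `score(embeddings[0], embeddings[1:])` to rank the two passages against the question in `sentences[0]`.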