模型:
setu4993/LaBSE
语言无关的BERT句子编码器(LaBSE)是一个基于BERT的模型,用于109种语言的句子嵌入。预训练过程将掩码语言建模与翻译语言建模相结合。该模型对于获取多语言句子嵌入和双向文本检索非常有用。
这是从TF Hub的v2模型迁移而来的,该模型使用基于字典的输入。两个版本的模型产生的嵌入结果是 equivalent .
使用该模型:
import torch from transformers import BertModel, BertTokenizerFast tokenizer = BertTokenizerFast.from_pretrained("setu4993/LaBSE") model = BertModel.from_pretrained("setu4993/LaBSE") model = model.eval() english_sentences = [ "dog", "Puppies are nice.", "I enjoy taking long walks along the beach with my dog.", ] english_inputs = tokenizer(english_sentences, return_tensors="pt", padding=True) with torch.no_grad(): english_outputs = model(**english_inputs)
要获取句子嵌入,请使用汇聚器输出:
english_embeddings = english_outputs.pooler_output
其他语言的输出:
italian_sentences = [ "cane", "I cuccioli sono carini.", "Mi piace fare lunghe passeggiate lungo la spiaggia con il mio cane.", ] japanese_sentences = ["犬", "子犬はいいです", "私は犬と一緒にビーチを散歩するのが好きです"] italian_inputs = tokenizer(italian_sentences, return_tensors="pt", padding=True) japanese_inputs = tokenizer(japanese_sentences, return_tensors="pt", padding=True) with torch.no_grad(): italian_outputs = model(**italian_inputs) japanese_outputs = model(**japanese_inputs) italian_embeddings = italian_outputs.pooler_output japanese_embeddings = japanese_outputs.pooler_output
对于句子之间的相似度,建议在计算相似度之前进行L2-norm:
import torch.nn.functional as F def similarity(embeddings_1, embeddings_2): normalized_embeddings_1 = F.normalize(embeddings_1, p=2) normalized_embeddings_2 = F.normalize(embeddings_2, p=2) return torch.matmul( normalized_embeddings_1, normalized_embeddings_2.transpose(0, 1) ) print(similarity(english_embeddings, italian_embeddings)) print(similarity(english_embeddings, japanese_embeddings)) print(similarity(italian_embeddings, japanese_embeddings))
有关数据、训练、评估和性能指标的详细信息,请参阅 original paper .
@misc{feng2020languageagnostic, title={Language-agnostic BERT Sentence Embedding}, author={Fangxiaoyu Feng and Yinfei Yang and Daniel Cer and Naveen Arivazhagan and Wei Wang}, year={2020}, eprint={2007.01852}, archivePrefix={arXiv}, primaryClass={cs.CL} }