模型:
facebook/rag-token-nq
这是Patrick Lewis、Ethan Perez、Aleksandara Piktus等人的论文 Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks 的RAG-Token模型。
该模型是一个非大小写敏感模型,即将大写字母转换为小写字母。
该模型包括一个question_encoder、一个retriever和一个generator。检索器从链接上方的wiki_dpr训练数据集中提取相关段落。question_encoder和retriever基于facebook/dpr-question_encoder-single-nq-base和facebook/bart-large,它们在wiki_dpr QA数据集上进行了联合微调,以端到端方式工作。
请注意:在下面的用法示例中,只使用了wiki_dpr的虚拟retriever,因为完整的遗留索引需要超过75GB的RAM。该模型可以回答任何事实型问题,方法如下:
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq") retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True) model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever) input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", return_tensors="pt") generated = model.generate(input_ids=input_dict["input_ids"]) print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0]) # should give michael phelps => sounds reasonable