模型:
facebook/data2vec-audio-base-960h
基于Librispeech的960小时16kHz采样语音音频进行预训练和微调的基础模型。使用该模型时,请确保语音输入也以16Khz进行采样。
作者:Alexei Baevski,Wei-Ning Hsu,Qiantong Xu,Arun Babu,Jiatao Gu,Michael Auli
摘要
尽管跨模态的自监督学习的总体思想是相同的,但实际的算法和目标因不同模态而异。为了让我们更接近通用的自监督学习,我们提出了data2vec,这是一个框架,可以在语音、自然语言处理或计算机视觉中使用相同的学习方法。其核心思想是在自蒸馏设置中使用标准Transformer架构,基于输入的屏蔽视图来预测完整输入数据的潜在表示。data2vec不是预测特定于模态的目标,例如单词、视觉标记或人声单位,而是预测包含来自整个输入的信息的上下文化潜在表示。在语音识别、图像分类和自然语言理解的主要基准测试上进行的实验表明,data2vec取得了新的最先进或与主导方法相比具有竞争力的性能。
原始模型可在 https://github.com/pytorch/fairseq/tree/main/examples/data2vec 下找到。
更多信息,请参阅 official paper 。
要转录音频文件,可以将模型用作独立的声学模型,如下所示:
from transformers import Wav2Vec2Processor, Data2VecForCTC from datasets import load_dataset import torch # load model and processor processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h") model = Data2VecForCTC.from_pretrained("facebook/data2vec-audio-base-960h") # load dummy dataset and read soundfiles ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") # tokenize input_values = processor(ds[0]["audio"]["array"],, return_tensors="pt", padding="longest").input_values # Batch size 1 # retrieve logits logits = model(input_values).logits # take argmax and decode predicted_ids = torch.argmax(logits, dim=-1) transcription = processor.batch_decode(predicted_ids)
此代码片段显示如何在LibriSpeech的“clean”和“other”测试数据上评估facebook/data2vec-audio-base-960h。
from transformers import Wav2Vec2Processor, Data2VecForCTC from datasets import load_dataset import torch from jiwer import wer # load model and processor processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h").to("cuda") model = Data2VecForCTC.from_pretrained("facebook/data2vec-audio-base-960h") librispeech_eval = load_dataset("librispeech_asr", "clean", split="test") def map_to_pred(batch): input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values with torch.no_grad(): logits = model(input_values.to("cuda")).logits predicted_ids = torch.argmax(logits, dim=-1) transcription = processor.batch_decode(predicted_ids) batch["transcription"] = transcription return batch result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"]) print("WER:", wer(result["text"], result["transcription"]))
结果(WER):
"clean" | "other" |
---|---|
2.77 | 7.08 |