英文

Data2Vec-Audio-Base-960h

Facebook's Data2Vec

基于Librispeech的960小时16kHz采样语音音频进行预训练和微调的基础模型。使用该模型时,请确保语音输入也以16Khz进行采样。

Paper

作者:Alexei Baevski,Wei-Ning Hsu,Qiantong Xu,Arun Babu,Jiatao Gu,Michael Auli

摘要

尽管跨模态的自监督学习的总体思想是相同的,但实际的算法和目标因不同模态而异。为了让我们更接近通用的自监督学习,我们提出了data2vec,这是一个框架,可以在语音、自然语言处理或计算机视觉中使用相同的学习方法。其核心思想是在自蒸馏设置中使用标准Transformer架构,基于输入的屏蔽视图来预测完整输入数据的潜在表示。data2vec不是预测特定于模态的目标,例如单词、视觉标记或人声单位,而是预测包含来自整个输入的信息的上下文化潜在表示。在语音识别、图像分类和自然语言理解的主要基准测试上进行的实验表明,data2vec取得了新的最先进或与主导方法相比具有竞争力的性能。

原始模型可在 https://github.com/pytorch/fairseq/tree/main/examples/data2vec 下找到。

预训练方法

更多信息,请参阅 official paper

用法

要转录音频文件,可以将模型用作独立的声学模型,如下所示:

 from transformers import Wav2Vec2Processor, Data2VecForCTC
 from datasets import load_dataset
 import torch
 
 # load model and processor
 processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
 model = Data2VecForCTC.from_pretrained("facebook/data2vec-audio-base-960h")
     
 # load dummy dataset and read soundfiles
 ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
 
 # tokenize
 input_values = processor(ds[0]["audio"]["array"],, return_tensors="pt", padding="longest").input_values  # Batch size 1
 
 # retrieve logits
 logits = model(input_values).logits
 
 # take argmax and decode
 predicted_ids = torch.argmax(logits, dim=-1)
 transcription = processor.batch_decode(predicted_ids)

评估

此代码片段显示如何在LibriSpeech的“clean”和“other”测试数据上评估facebook/data2vec-audio-base-960h。

 from transformers import Wav2Vec2Processor, Data2VecForCTC
 from datasets import load_dataset
 import torch
 from jiwer import wer
 
 # load model and processor
 processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h").to("cuda")
 model = Data2VecForCTC.from_pretrained("facebook/data2vec-audio-base-960h")
 

librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")

def map_to_pred(batch):
    input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = model(input_values.to("cuda")).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription
    return batch

result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])

print("WER:", wer(result["text"], result["transcription"]))

结果(WER):

"clean" "other"
2.77 7.08