Ajuste fino de XLSR-Wav2Vec2 para ASR de bajo recurso con 🤗 Transformers

Ajuste fino de XLSR-Wav2Vec2 para ASR de bajo recurso.

Nuevo (11/2021): Esta publicación del blog ha sido actualizada para presentar el sucesor de XLSR, llamado XLS-R.

Wav2Vec2 es un modelo preentrenado para el Reconocimiento Automático de Voz (ASR, por sus siglas en inglés) y fue lanzado en septiembre de 2020 por Alexei Baevski, Michael Auli y Alex Conneau. Poco después de demostrar el rendimiento superior de Wav2Vec2 en uno de los conjuntos de datos en inglés más populares para ASR, llamado LibriSpeech, Facebook AI presentó una versión multilingüe de Wav2Vec2, llamada XLSR. XLSR significa representaciones de voz cruzadas y se refiere a la capacidad del modelo para aprender representaciones de voz que son útiles en varios idiomas.

El sucesor de XLSR, simplemente llamado XLS-R (refiriéndose a “XLM-R para el habla”), fue lanzado en noviembre de 2021 por Arun Babu, Changhan Wang, Andros Tjandra, et al. XLS-R utilizó casi medio millón de horas de datos de audio en 128 idiomas para el preentrenamiento auto-supervisado y viene en tamaños que van desde 300 millones hasta dos mil millones de parámetros. Puedes encontrar los puntos de control preentrenados en el 🤗 Hub:

  • Wav2Vec2-XLS-R-300M
  • Wav2Vec2-XLS-R-1B
  • Wav2Vec2-XLS-R-2B

Similar al objetivo de modelado de lenguaje enmascarado de BERT, XLS-R aprende representaciones de voz contextualizadas mediante el enmascaramiento aleatorio de vectores de características antes de pasarlos a una red transformadora durante el preentrenamiento auto-supervisado (es decir, diagrama a la izquierda a continuación).

Para el ajuste fino, se agrega una capa lineal única encima de la red preentrenada para entrenar el modelo con datos etiquetados de tareas de audio downstream, como el reconocimiento de voz, la traducción de voz y la clasificación de audio (es decir, diagrama a la derecha a continuación).

XLS-R muestra mejoras impresionantes sobre los resultados anteriores del estado del arte tanto en reconocimiento de voz, traducción de voz como en identificación de hablante/idioma, cf. con las Tablas 3-6, 7-10 y 11-12 respectivamente del artículo oficial.

Configuración

En este blog, daremos una explicación detallada de cómo se puede ajustar fino XLS-R, más específicamente el punto de control preentrenado Wav2Vec2-XLS-R-300M, para ASR.

Para fines de demostración, ajustaremos fino el modelo en el conjunto de datos de bajo recurso de ASR de Common Voice, que contiene solo aproximadamente 4 horas de datos de entrenamiento validados.

XLS-R se ajusta fino utilizando la Clasificación Temporal Conexionalista (CTC, por sus siglas en inglés), que es un algoritmo que se utiliza para entrenar redes neuronales en problemas de secuencia a secuencia, como ASR y reconocimiento de escritura a mano.

Recomiendo encarecidamente leer la bien escrita publicación del blog “Modelado de Secuencias con CTC (2017)” de Awni Hannun.

Antes de comenzar, instalemos datasets y transformers. Además, necesitamos torchaudio para cargar archivos de audio y jiwer para evaluar nuestro modelo ajustado fino utilizando la métrica de tasa de error de palabras (WER) 1 {}^1 1.

!pip install datasets==1.18.3
!pip install transformers==4.11.3
!pip install huggingface_hub==0.1
!pip install torchaudio
!pip install librosa
!pip install jiwer

Sugerimos encarecidamente cargar tus puntos de control de entrenamiento directamente en el Hugging Face Hub durante el entrenamiento. El Hugging Face Hub tiene control de versiones integrado, por lo que puedes asegurarte de que no se perderá ningún punto de control del modelo durante el entrenamiento.

Para hacerlo, debes almacenar tu token de autenticación del sitio web de Hugging Face (regístrate aquí si aún no lo has hecho).

from huggingface_hub import notebook_login

notebook_login()

Salida impresa:

    Inicio de sesión exitoso
    Tu token ha sido guardado en /root/.huggingface/token

Luego necesitas instalar Git-LFS para cargar los puntos de control de tu modelo:

apt install git-lfs

1 {}^1 1 En el artículo, el modelo fue evaluado utilizando la tasa de error de fonemas (PER), pero de lejos la métrica más común en ASR es la tasa de error de palabras (WER). Para mantener este cuaderno lo más general posible, decidimos evaluar el modelo utilizando WER.

Preparar Datos, Tokenizer, Extractor de Características

Los modelos ASR transcriben el habla a texto, lo que significa que necesitamos tanto un extractor de características que procese la señal de voz al formato de entrada del modelo, por ejemplo, un vector de características, y un tokenizer que procese el formato de salida del modelo a texto.

En 🤗 Transformers, el modelo XLS-R está acompañado por un tokenizer llamado Wav2Vec2CTCTokenizer y un extractor de características llamado Wav2Vec2FeatureExtractor.

Comencemos creando el tokenizer para decodificar las clases de salida predichas a la transcripción de salida.

Crear Wav2Vec2CTCTokenizer

Un modelo XLS-R pre-entrenado asigna la señal de voz a una secuencia de representaciones de contexto como se muestra en la figura anterior. Sin embargo, para el reconocimiento de voz, el modelo tiene que asignar esta secuencia de representaciones de contexto a su transcripción correspondiente, lo que significa que se debe agregar una capa lineal encima del bloque transformador (mostrado en amarillo en el diagrama anterior). Esta capa lineal se utiliza para clasificar cada representación de contexto en una clase de token, análogamente a cómo se agrega una capa lineal encima de las incrustaciones de BERT para una clasificación adicional después del pre-entrenamiento ( cf. con la sección ‘BERT’ del siguiente blog post ). después del pre-entrenamiento se agrega una capa lineal encima de las incrustaciones de BERT para una clasificación adicional – cf. con la sección ‘BERT’ de este blog post .

El tamaño de salida de esta capa corresponde al número de tokens en el vocabulario, que no depende de la tarea de pre-entrenamiento de XLS-R, sino solo del conjunto de datos etiquetados utilizado para el ajuste fino. Por lo tanto, en el primer paso, echaremos un vistazo al conjunto de datos elegido de Common Voice y definiremos un vocabulario basado en las transcripciones.

Primero, vayamos al sitio web oficial de Common Voice y elijamos un idioma para ajustar fino XLS-R. Para este cuaderno, usaremos el turco.

Para cada conjunto de datos específico del idioma, puedes encontrar un código de idioma correspondiente a tu idioma elegido. En Common Voice, busca el campo “Versión”. El código de idioma corresponde al prefijo antes del guión bajo. Por ejemplo, para el turco, el código de idioma es "tr".

Genial, ahora podemos utilizar la API simple de 🤗 Datasets para descargar los datos. El nombre del conjunto de datos es "common_voice", el nombre de configuración corresponde al código de idioma, que en nuestro caso es "tr".

Common Voice tiene muchos conjuntos de datos diferentes, incluido invalidated, que se refiere a datos que no se consideraron lo suficientemente “limpios” como para ser considerados útiles. En este cuaderno, solo utilizaremos los conjuntos de datos "train", "validation" y "test".

Debido a que el conjunto de datos turco es muy pequeño, fusionaremos los datos de validación y entrenamiento en un conjunto de datos de entrenamiento y solo utilizaremos los datos de prueba para la validación.

from datasets import load_dataset, load_metric, Audio

common_voice_train = load_dataset("common_voice", "tr", split="train+validation")
common_voice_test = load_dataset("common_voice", "tr", split="test")

Muchos conjuntos de datos ASR solo proporcionan el texto objetivo, 'sentence' para cada matriz de audio 'audio' y archivo 'path'. Common Voice en realidad proporciona mucha más información sobre cada archivo de audio, como el 'accent', etc. Manteniendo el cuaderno lo más general posible, solo consideramos el texto transcrito para el ajuste fino.

common_voice_train = common_voice_train.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])
common_voice_test = common_voice_test.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "segment", "up_votes"])

Vamos a escribir una función corta para mostrar algunas muestras aleatorias del conjunto de datos y ejecutarla varias veces para tener una idea de las transcripciones.

from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def mostrar_elementos_aleatorios(dataset, num_ejemplos=10):
    assert num_ejemplos <= len(dataset), "No se pueden seleccionar más elementos de los que hay en el conjunto de datos."
    selecciones = []
    for _ in range(num_ejemplos):
        seleccion = random.randint(0, len(dataset)-1)
        while seleccion in selecciones:
            seleccion = random.randint(0, len(dataset)-1)
        selecciones.append(seleccion)
    
    df = pd.DataFrame(dataset[selecciones])
    display(HTML(df.to_html()))

Resultado de impresión:

¡Bien! Las transcripciones se ven bastante limpias. Después de traducir las frases transcritas, parece que el lenguaje corresponde más a un texto escrito que a un diálogo ruidoso. Esto tiene sentido considerando que Common Voice es un corpus de habla leída de crowdsourcing.

Podemos ver que las transcripciones contienen algunos caracteres especiales, como ,.?!;:. Sin un modelo de lenguaje, es mucho más difícil clasificar fragmentos de habla en esos caracteres especiales porque no corresponden realmente a una unidad de sonido característica. Por ejemplo, la letra "s" tiene un sonido más o menos claro, mientras que el carácter especial "." no lo tiene. Además, para comprender el significado de una señal de habla, generalmente no es necesario incluir caracteres especiales en la transcripción.

Simplemente eliminemos todos los caracteres que no contribuyan al significado de una palabra y que no puedan representarse realmente mediante un sonido acústico, y normalicemos el texto.

import re
expresion_regular_caracteres_a_eliminar = '[\,\?\.\!\-\;\:\"\“\%\‘\”\�\']'

def eliminar_caracteres_especiales(batch):
    batch["frase"] = re.sub(expresion_regular_caracteres_a_eliminar, '', batch["frase"]).lower()
    return batch

common_voice_entrenamiento = common_voice_entrenamiento.map(eliminar_caracteres_especiales)
common_voice_prueba = common_voice_prueba.map(eliminar_caracteres_especiales)

Echemos otro vistazo a las etiquetas de texto procesadas.

mostrar_elementos_aleatorios(common_voice_entrenamiento.remove_columns(["ruta","audio"]))

Resultado de impresión:

¡Bien! Esto se ve mejor. Hemos eliminado la mayoría de los caracteres especiales de las transcripciones y las hemos normalizado a minúsculas solamente.

Antes de finalizar el preprocesamiento, siempre es ventajoso consultar a un hablante nativo del idioma objetivo para ver si el texto se puede simplificar aún más. Para esta publicación de blog, Merve tuvo la amabilidad de echar un vistazo rápido y señaló que los caracteres “con sombrero” – como â – ya no se usan realmente en turco y se pueden reemplazar por su equivalente “sin sombrero”, por ejemplo a.

Esto significa que debemos reemplazar una frase como "yargı sistemi hâlâ sağlıksız" por "yargı sistemi hala sağlıksız".

Escribamos otra función de mapeo corta para simplificar aún más las etiquetas de texto. Recuerda, cuanto más simples sean las etiquetas de texto, más fácil será para el modelo aprender a predecir esas etiquetas.

def reemplazar_caracteres_con_sombrero(batch):
    batch["frase"] = re.sub('[â]', 'a', batch["frase"])
    batch["frase"] = re.sub('[î]', 'i', batch["frase"])
    batch["frase"] = re.sub('[ô]', 'o', batch["frase"])
    batch["frase"] = re.sub('[û]', 'u', batch["frase"])
    return batch

common_voice_entrenamiento = common_voice_entrenamiento.map(reemplazar_caracteres_con_sombrero)
common_voice_prueba = common_voice_prueba.map(reemplazar_caracteres_con_sombrero)

En CTC, es común clasificar fragmentos de habla en letras, así que haremos lo mismo aquí. Extraremos todas las letras distintas de los datos de entrenamiento y prueba y construiremos nuestro vocabulario a partir de este conjunto de letras.

Escribimos una función de mapeo que concatena todas las transcripciones en una transcripción larga y luego convierte la cadena en un conjunto de caracteres. Es importante pasar el argumento batched=True a la función map(...) para que la función de mapeo tenga acceso a todas las transcripciones a la vez.

def extraer_todas_las_letras(batch):
  todo_texto = " ".join(batch["frase"])
  vocabulario = list(set(todo_texto))
  return {"vocabulario": [vocabulario], "todo_texto": [todo_texto]}

vocabulario_entrenamiento = common_voice_entrenamiento.map(extraer_todas_las_letras, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_entrenamiento.column_names)
vocabulario_prueba = common_voice_prueba.map(extraer_todas_las_letras, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_prueba.column_names)

Ahora, creamos la unión de todas las letras distintas en el conjunto de datos de entrenamiento y el conjunto de datos de prueba, y convertimos la lista resultante en un diccionario enumerado.

vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(sorted(vocab_list))}
vocab_dict

Resultado de la impresión:

{
 ' ': 0,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26,
 'ç': 27,
 'ë': 28,
 'ö': 29,
 'ü': 30,
 'ğ': 31,
 'ı': 32,
 'ş': 33,
 '̇': 34
}

Genial, vemos que todas las letras del alfabeto aparecen en el conjunto de datos (lo cual no es sorprendente) y también hemos extraído los caracteres especiales "" y '. Note que no excluimos esos caracteres especiales porque:

El modelo debe aprender a predecir cuando una palabra ha terminado, de lo contrario la predicción del modelo siempre sería una secuencia de caracteres, lo que haría imposible separar las palabras entre sí.

Siempre se debe tener en cuenta que el preprocesamiento es un paso muy importante antes de entrenar su modelo. Por ejemplo, no queremos que nuestro modelo diferencie entre a y A solo porque olvidamos normalizar los datos. La diferencia entre a y A no depende en absoluto del “sonido” de la letra, sino más bien de las reglas gramaticales, como usar una letra en mayúscula al comienzo de la oración. Por lo tanto, tiene sentido eliminar la diferencia entre letras en mayúscula y minúscula para que el modelo tenga más facilidad para aprender a transcribir el habla.

Para dejar claro que " " tiene su propia clase de token, le damos un carácter más visible |. Además, también agregamos un token de “desconocido” para que el modelo pueda manejar caracteres que no se encuentren en el conjunto de entrenamiento de Common Voice.

vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

Finalmente, también agregamos un token de relleno que corresponde al “token en blanco” de CTC. El “token en blanco” es un componente clave del algoritmo CTC. Para obtener más información, consulte la sección “Alignment” aquí .

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

Genial, ahora nuestro vocabulario está completo y consta de 39 tokens, lo que significa que la capa lineal que agregaremos encima del punto de control XLS-R preentrenado tendrá una dimensión de salida de 39.

Ahora guardemos el vocabulario como un archivo json.

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

En un paso final, usamos el archivo json para cargar el vocabulario en una instancia de la clase Wav2Vec2CTCTokenizer.

from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("./", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

Si desea reutilizar el tokenizador recién creado con el modelo ajustado de este cuaderno, se recomienda cargar el tokenizer en Hugging Face Hub. Llamemos al repositorio al que subiremos los archivos "wav2vec2-large-xlsr-turkish-demo-colab":

repo_name = "wav2vec2-large-xls-r-300m-tr-colab"

y carga el tokenizer en el 🤗 Hub .

tokenizer.push_to_hub(repo_name)

Genial, puedes ver el repositorio recién creado en https://huggingface.co/<your-username>/wav2vec2-large-xls-r-300m-tr-colab

Crear Wav2Vec2FeatureExtractor

El habla es una señal continua y, para ser tratada por las computadoras, primero debe ser discretizada, lo cual generalmente se llama muestreo. La tasa de muestreo juega un papel importante ya que define cuántos puntos de datos de la señal de habla se miden por segundo. Por lo tanto, muestrear con una tasa de muestreo más alta resulta en una mejor aproximación de la señal de habla real, pero también requiere más valores por segundo.

Un punto de control preentrenado espera que sus datos de entrada hayan sido muestreados más o menos de la misma distribución que los datos en los que fue entrenado. Las mismas señales de habla muestreadas a dos tasas diferentes tienen una distribución muy diferente. Por ejemplo, duplicar la tasa de muestreo hace que los puntos de datos sean el doble de largos. Por lo tanto, antes de ajustar finamente un punto de control preentrenado de un modelo ASR, es crucial verificar que la tasa de muestreo de los datos que se utilizaron para preentrenar el modelo coincida con la tasa de muestreo del conjunto de datos utilizado para ajustar finamente el modelo.

XLS-R fue preentrenado en datos de audio de Babel, Multilingual LibriSpeech (MLS), Common Voice, VoxPopuli y VoxLingua107 a una tasa de muestreo de 16 kHz. Common Voice, en su forma original, tiene una tasa de muestreo de 48 kHz, por lo que tendremos que reducir la tasa de muestreo de los datos de ajuste fino a 16 kHz a continuación.

Un objeto Wav2Vec2FeatureExtractor requiere los siguientes parámetros para ser instanciado:

  • feature_size: Los modelos de habla toman una secuencia de vectores de características como entrada. Si bien la longitud de esta secuencia obviamente varía, el tamaño de la característica no debería hacerlo. En el caso de Wav2Vec2, el tamaño de la característica es 1 porque el modelo fue entrenado en la señal de habla cruda 2 {}^2 2.
  • sampling_rate: La tasa de muestreo con la que se entrenó el modelo.
  • padding_value: Para inferencia en lotes, las entradas más cortas deben ser rellenadas con un valor específico.
  • do_normalize: Si la entrada debe normalizarse con media cero y varianza unitaria o no. Por lo general, los modelos de habla funcionan mejor cuando se normaliza la entrada.
  • return_attention_mask: Si el modelo debe utilizar una attention_mask para inferencia en lotes. En general, los puntos de control de los modelos XLS-R deben usar siempre la attention_mask.
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)

¡Genial, el proceso de extracción de características de XLS-R está completamente definido!

Para una mayor facilidad de uso, el extractor de características y el tokenizer se envuelven en una única clase Wav2Vec2Processor para que solo se necesite un objeto model y processor.

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

A continuación, podemos preparar el conjunto de datos.

Preprocesar datos

Hasta ahora, no hemos examinado los valores reales de la señal de habla, sino solo la transcripción. Además de sentence, nuestros conjuntos de datos incluyen dos nombres de columna más: path y audio. path indica la ruta absoluta del archivo de audio. Echemos un vistazo.

common_voice_train[0]["path"]

XLS-R espera la entrada en el formato de una matriz unidimensional de 16 kHz. Esto significa que el archivo de audio debe cargarse y remuestrearse.

Afortunadamente, datasets hace esto automáticamente llamando a la otra columna audio. Vamos a probarlo.

common_voice_train[0]["audio"]

    {'array': array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
            -8.8930130e-05, -3.8027763e-05, -2.9146671e-05], dtype=float32),
     'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
     'sampling_rate': 48000}

Genial, podemos ver que el archivo de audio se ha cargado automáticamente. Esto se debe a la nueva característica “Audio” introducida en datasets == 1.18.3, que carga y re-muestrea archivos de audio al vuelo al llamarlos.

En el ejemplo anterior podemos ver que los datos de audio se cargan con una frecuencia de muestreo de 48kHz, mientras que el modelo espera 16kHz. Podemos establecer la característica de audio a la frecuencia de muestreo correcta utilizando cast_column:

common_voice_train = common_voice_train.cast_column("audio", Audio(sampling_rate=16_000))
common_voice_test = common_voice_test.cast_column("audio", Audio(sampling_rate=16_000))

Echemos otro vistazo a “audio”.

common_voice_train[0]["audio"]

    {'array': array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
            -7.4556941e-05, -1.4621433e-05, -5.7861507e-05], dtype=float32),
     'path': '/root/.cache/huggingface/datasets/downloads/extracted/05be0c29807a73c9b099873d2f5975dae6d05e9f7d577458a2466ecb9a2b0c6b/cv-corpus-6.1-2020-12-11/tr/clips/common_voice_tr_21921195.mp3',
     'sampling_rate': 16000}

¡Parece que ha funcionado! Escuchemos algunos archivos de audio para comprender mejor el conjunto de datos y verificar que el audio se haya cargado correctamente.

import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(common_voice_train)-1)

print(common_voice_train[rand_int]["sentence"])
ipd.Audio(data=common_voice_train[rand_int]["audio"]["array"], autoplay=True, rate=16000)

Salida de impresión:

    sunulan bütün teklifler i̇ngilizce idi

Parece que los datos se han cargado y re-muestreado correctamente.

Se puede escuchar que los hablantes cambian junto con su velocidad de habla, acento, entorno de fondo, etc. En general, las grabaciones suenan aceptablemente claras, como cabría esperar de un corpus de habla leída creado por la comunidad.

Hagamos una verificación final de que los datos se hayan preparado correctamente, imprimiendo la forma de la entrada de voz, su transcripción y la frecuencia de muestreo correspondiente.

rand_int = random.randint(0, len(common_voice_train)-1)

print("Texto objetivo:", common_voice_train[rand_int]["sentence"])
print("Forma del arreglo de entrada:", common_voice_train[rand_int]["audio"]["array"].shape)
print("Frecuencia de muestreo:", common_voice_train[rand_int]["audio"]["sampling_rate"])

Salida de impresión:

    Texto objetivo: makedonya bu yıl otuz adet tyetmiş iki tankı aldı
    Forma del arreglo de entrada: (71040,)
    Frecuencia de muestreo: 16000

¡Bien! Todo parece estar bien: los datos son un arreglo unidimensional, la frecuencia de muestreo siempre corresponde a 16kHz y el texto objetivo está normalizado.

Finalmente, podemos aprovechar Wav2Vec2Processor para procesar los datos al formato esperado por Wav2Vec2ForCTC para el entrenamiento. Para hacerlo, usemos la función map(…) de Dataset.

Primero, cargamos y remuestreamos los datos de audio, simplemente llamando a batch["audio"]. Segundo, extraemos los input_values del archivo de audio cargado. En nuestro caso, el Wav2Vec2Processor solo normaliza los datos. Sin embargo, para otros modelos de habla, este paso puede incluir una extracción de características más compleja, como la extracción de características Log-Mel. Tercero, codificamos las transcripciones en identificadores de etiqueta.

Nota: Esta función de mapeo es un buen ejemplo de cómo se debe usar la clase Wav2Vec2Processor. En un contexto “normal”, llamar a processor(...) se redirige al método de llamada de Wav2Vec2FeatureExtractor. Sin embargo, al envolver el procesador en el contexto as_target_processor, el mismo método se redirige al método de llamada de Wav2Vec2CTCTokenizer. Para obtener más información, consulte la documentación.

def prepare_dataset(batch):
    audio = batch["audio"]

    # el resultado agrupado no tiene batch
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["sentence"]).input_ids
    return batch

Apliquemos la función de preparación de datos a todos los ejemplos.

common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names)
common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)

Nota: Actualmente, datasets utiliza torchaudio y librosa para cargar y remuestrear audio. Si desea implementar su propio proceso de carga/remuestreo de datos personalizado, simplemente use la columna "path" y no tenga en cuenta la columna "audio".

Las secuencias de entrada largas requieren mucha memoria. XLS-R se basa en la autoatención. El requisito de memoria aumenta cuadráticamente con la longitud de entrada para secuencias de entrada largas (cf. este post en Reddit). En caso de que esta demostración se bloquee con un error de “Memoria insuficiente”, puede descomentar las siguientes líneas para filtrar todas las secuencias que sean más largas de 5 segundos para el entrenamiento.

#max_input_length_in_sec = 5.0
#common_voice_train = common_voice_train.filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])

¡Increíble, ahora estamos listos para comenzar el entrenamiento!

Entrenamiento

Los datos se procesan para que estemos listos para comenzar a configurar el pipeline de entrenamiento. Utilizaremos el Trainer de 🤗, para lo cual esencialmente necesitamos hacer lo siguiente:

  • Definir un recolector de datos. A diferencia de la mayoría de los modelos de procesamiento del lenguaje natural, XLS-R tiene una longitud de entrada mucho mayor que la longitud de salida. Por ejemplo, una muestra de longitud de entrada 50000 tiene una longitud de salida de no más de 100. Dadas las grandes tamaños de entrada, es mucho más eficiente rellenar los lotes de entrenamiento de forma dinámica, lo que significa que todas las muestras de entrenamiento solo deben rellenarse hasta la muestra más larga en su lote y no la muestra más larga en general. Por lo tanto, el ajuste fino de XLS-R requiere un recolector de datos de relleno especial, que definiremos a continuación

  • Métrica de evaluación. Durante el entrenamiento, el modelo debe evaluarse en la tasa de error de palabras. Deberemos definir una función compute_metrics en consecuencia

  • Cargar un punto de control preentrenado. Necesitamos cargar un punto de control preentrenado y configurarlo correctamente para el entrenamiento.

  • Definir la configuración de entrenamiento.

Después de haber ajustado finamente el modelo, lo evaluaremos correctamente en los datos de prueba y verificaremos que realmente haya aprendido a transcribir el habla correctamente.

Configurar el Trainer

Comencemos definiendo el recolector de datos. El código del recolector de datos se copió de este ejemplo.

Sin entrar en demasiados detalles, a diferencia de los recolectores de datos comunes, este recolector de datos trata los input_values y labels de manera diferente y, por lo tanto, aplica funciones de relleno separadas en ellos (nuevamente haciendo uso del administrador de contexto del procesador XLS-R). Esto es necesario porque en el habla, la entrada y la salida son de diferentes modalidades, lo que significa que no deben ser tratadas por la misma función de relleno. De manera análoga a los recolectores de datos comunes, los tokens de relleno en las etiquetas tienen un valor de -100 para que esos tokens no se tengan en cuenta al calcular la pérdida.

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Rellena hasta la secuencia más larga en el lote (o no rellena si solo hay una secuencia).
            * :obj:`'max_length'`: Rellena hasta una longitud máxima especificada con el argumento :obj:`max_length` o hasta la longitud de entrada máxima aceptable para el modelo si no se proporciona ese argumento.
            * :obj:`False` or :obj:`'do_not_pad'` (default): Sin relleno (es decir, puede generar un lote con secuencias de diferentes longitudes).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # divide las entradas y las etiquetas ya que deben tener longitudes diferentes y necesitan
        # diferentes métodos de relleno
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # reemplaza el relleno con -100 para ignorar correctamente la pérdida
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

A continuación, se define la métrica de evaluación. Como se mencionó anteriormente, la métrica predominante en ASR es la tasa de error de palabras (WER), por lo tanto, también la utilizaremos en este cuaderno.

wer_metric = load_metric("wer")

El modelo devolverá una secuencia de vectores de logitos: y 1 , … , y m \mathbf{y}_1, \ldots, \mathbf{y}_m y 1 ​ , … , y m ​ con y 1 = f θ ( x 1 , … , x n ) [ 0 ] \mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0] y 1 ​ = f θ ​ ( x 1 ​ , … , x n ​ ) [ 0 ] y n > > m n >> m n > > m .

Un vector de logitos y 1 \mathbf{y}_1 y 1 ​ contiene las probabilidades logarítmicas para cada palabra en el vocabulario que definimos anteriormente, por lo tanto len ( y i ) = \text{len}(\mathbf{y}_i) = len ( y i ​ ) = config.vocab_size . Estamos interesados en la predicción más probable del modelo y por lo tanto tomamos argmax(...) de los logitos. Además, transformamos las etiquetas codificadas de vuelta a la cadena original reemplazando -100 con pad_token_id y decodificando los ids asegurándonos de que los tokens consecutivos no se agrupen en el mismo token en el estilo CTC 1 {}^1 1 .

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # no queremos agrupar tokens al calcular las métricas
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

Ahora, podemos cargar el punto de control preentrenado de Wav2Vec2-XLS-R-300M . El pad_token_id del tokenizador debe ser definido como el pad_token_id del modelo o en el caso de Wav2Vec2ForCTC también el token en blanco de CTC 2 {}^2 2 . Para ahorrar memoria de la GPU, habilitamos el checkpointing de gradientes de PyTorch y también configuramos la reducción de pérdida como ” mean “.

Debido a que el conjunto de datos es bastante pequeño (~6h de datos de entrenamiento) y debido a que Common Voice es bastante ruidoso, parece que se requiere ajustar algunos hiperparámetros para afinar el punto de control wav2vec2-xls-r-300m de Facebook. Por lo tanto, tuve que probar diferentes valores para la eliminación aleatoria, la tasa de eliminación de enmascaramiento de SpecAugment, la eliminación de capa y la tasa de aprendizaje hasta que el entrenamiento pareciera ser lo suficientemente estable.

Nota: Al usar este cuaderno para entrenar XLS-R en otro idioma de Common Voice, es posible que esos ajustes de hiperparámetros no funcionen muy bien. Siéntete libre de adaptarlos según tus necesidades.

from transformers import Wav2Vec2ForCTC

modelo = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m", 
    attention_dropout=0.0,
    hidden_dropout=0.0,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.0,
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer),
)

El primer componente de XLS-R consta de una pila de capas CNN que se utilizan para extraer características acústicamente significativas, pero independientes contextualmente, de la señal de voz en bruto. Esta parte del modelo ya ha sido suficientemente entrenada durante la preentrenamiento y, como se indica en el artículo, no es necesario afinarla más. Por lo tanto, podemos establecer requires_grad en False para todos los parámetros de la parte de extracción de características.

modelo.freeze_feature_extractor()

En un paso final, definimos todos los parámetros relacionados con el entrenamiento. Para dar más explicación sobre algunos de los parámetros:

  • group_by_length hace que el entrenamiento sea más eficiente agrupando las muestras de entrenamiento de longitud de entrada similar en un lote. Esto puede acelerar significativamente el tiempo de entrenamiento al reducir en gran medida el número total de tokens de relleno inútiles que pasan por el modelo
  • learning_rate y weight_decay se ajustaron heurísticamente hasta que el ajuste fino se volvió estable. Ten en cuenta que esos parámetros dependen en gran medida del conjunto de datos de Common Voice y pueden no ser óptimos para otros conjuntos de datos de voz.

Para obtener más explicaciones sobre otros parámetros, se puede consultar la documentación.

Durante el entrenamiento, se cargará un punto de control de forma asincrónica en el Hub cada 400 pasos de entrenamiento. Esto te permite jugar con el widget de demostración incluso mientras tu modelo aún se está entrenando.

Nota: Si no se desea cargar los puntos de control del modelo en el Hub, simplemente establece push_to_hub=False.

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir=nombre_repo,
  group_by_length=True,
  per_device_train_batch_size=16,
  gradient_accumulation_steps=2,
  evaluation_strategy="steps",
  num_train_epochs=30,
  gradient_checkpointing=True,
  fp16=True,
  save_steps=400,
  eval_steps=400,
  logging_steps=400,
  learning_rate=3e-4,
  warmup_steps=500,
  save_total_limit=2,
  push_to_hub=True,
)

Ahora, todas las instancias se pueden pasar al Entrenador y estamos listos para comenzar el entrenamiento.

from transformers import Trainer

entrenador = Trainer(
    model=modelo,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=common_voice_train,
    eval_dataset=common_voice_test,
    tokenizer=processor.feature_extractor,
)

1 {}^1 1 Para permitir que los modelos sean independientes de la velocidad del hablante, en CTC, los tokens consecutivos que son idénticos se agrupan simplemente como un solo token. Sin embargo, las etiquetas codificadas no deben agruparse al decodificar ya que no corresponden a los tokens predichos por el modelo, por lo que se debe pasar el parámetro group_tokens=False. Si no pasáramos este parámetro, una palabra como "hello" se codificaría incorrectamente y se decodificaría como "helo". 2 {}^2 2 El token en blanco permite que el modelo prediga una palabra, como "hello", al obligarlo a insertar el token en blanco entre las dos “l”. Una predicción conforme a CTC de "hello" de nuestro modelo sería [PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD].

Entrenamiento

El entrenamiento tomará varias horas dependiendo de la GPU asignada a este cuaderno. Si bien el modelo entrenado produce resultados algo satisfactorios en los datos de prueba de Turkish de Common Voice, no es en modo alguno un modelo optimizado y ajustado de manera óptima. El propósito de este cuaderno es simplemente demostrar cómo ajustar finamente XLS-R XLSR-Wav2Vec2 en un conjunto de datos de ASR.

Dependiendo de la GPU asignada a su google colab, es posible que vea un error de "sin memoria" aquí. En este caso, probablemente sea mejor reducir per_device_train_batch_size a 8 o incluso menos y aumentar gradient_accumulation.

trainer.train()

Resultado de impresión:

La pérdida de entrenamiento y el WER de validación disminuyen adecuadamente.

Ahora puede cargar el resultado del entrenamiento en el Hub, solo ejecute esta instrucción:

trainer.push_to_hub()

Ahora puede compartir este modelo con todos sus amigos, familiares, mascotas favoritas: todos pueden cargarlo con el identificador “nombre-de-usuario/el-nombre-que-eligió”, por ejemplo:

from transformers import AutoModelForCTC, Wav2Vec2Processor

model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-large-xls-r-300m-tr-colab")

Para obtener más ejemplos de cómo se puede ajustar finamente XLS-R, consulte los ejemplos oficiales de 🤗 Transformers.

Evaluación

Como última comprobación, carguemos el modelo y verifiquemos que realmente haya aprendido a transcribir el habla turca.

Primero carguemos el punto de control preentrenado.

model = Wav2Vec2ForCTC.from_pretrained(repo_name).to("cuda")
processor = Wav2Vec2Processor.from_pretrained(repo_name)

Ahora, simplemente tomaremos el primer ejemplo del conjunto de prueba, lo ejecutaremos a través del modelo y tomaremos el argmax(...) de los logits para recuperar los identificadores de token predichos.

input_dict = processor(common_voice_test[0]["input_values"], return_tensors="pt", padding=True)

logits = model(input_dict.input_values.to("cuda")).logits

pred_ids = torch.argmax(logits, dim=-1)[0]

Se recomienda encarecidamente pasar el argumento sampling_rate a esta función. No hacerlo puede resultar en errores silenciosos que pueden ser difíciles de depurar.

Adaptamos common_voice_test bastante para que la instancia del conjunto de datos ya no contenga la etiqueta de la oración original. Por lo tanto, reutilizamos el conjunto de datos original para obtener la etiqueta del primer ejemplo.

common_voice_test_transcription = load_dataset("common_voice", "tr", data_dir="./cv-corpus-6.1-2020-12-11", split="test")

Finalmente, podemos decodificar el ejemplo.

print("Predicción:")
print(processor.decode(pred_ids))

print("\nReferencia:")
print(common_voice_test_transcription[0]["sentence"].lower())

Resultado de impresión:

¡Bien! La transcripción definitivamente se puede reconocer en nuestra predicción, pero aún no es perfecta. Entrenar el modelo un poco más, dedicar más tiempo al preprocesamiento de datos y especialmente usar un modelo de lenguaje para la decodificación seguramente mejorarían el rendimiento general del modelo.

Para un modelo de demostración en un idioma de recursos limitados, sin embargo, los resultados son bastante aceptables 🤗.