Ajustar Wav2Vec2 para ASR en inglés en Hugging Face con 🤗 Transformers
'Ajustar Wav2Vec2 en Hugging Face para ASR en inglés con 🤗 Transformers'
Wav2Vec2 es un modelo preentrenado para el Reconocimiento Automático del Habla (ASR) y fue lanzado en septiembre de 2020 por Alexei Baevski, Michael Auli y Alex Conneau.
Usando un objetivo de preentrenamiento contrastivo novedoso, Wav2Vec2 aprende representaciones de habla poderosas a partir de más de 50.000 horas de habla no etiquetada. Similar al modelo de lenguaje enmascarado de BERT, el modelo aprende representaciones de habla contextualizadas al enmascarar aleatoriamente vectores de caracterÃsticas antes de pasarlos a una red transformadora.
Por primera vez, se ha demostrado que el preentrenamiento, seguido del ajuste fino con muy pocos datos de habla etiquetados, logra resultados competitivos en comparación con los sistemas de ASR de última generación. Utilizando tan solo 10 minutos de datos etiquetados, Wav2Vec2 logra una tasa de error de palabras (WER) de menos del 5% en el conjunto de prueba limpio de LibriSpeech – ver Tabla 9 del artÃculo.
- Mi viaje hacia una tuberÃa de transformadores sin servidor en Googl...
- La colaboración Amazon SageMaker y Hugging Face
- Presentamos 🤗 Accelerate
En este cuaderno, daremos una explicación detallada de cómo los puntos de control preentrenados de Wav2Vec2 se pueden ajustar finamente en cualquier conjunto de datos de ASR en inglés. Tenga en cuenta que en este cuaderno, ajustaremos finamente Wav2Vec2 sin hacer uso de un modelo de lenguaje. Es mucho más simple utilizar Wav2Vec2 sin un modelo de lenguaje como un sistema de ASR de extremo a extremo y se ha demostrado que un modelo acústico independiente de Wav2Vec2 logra resultados impresionantes. Con fines de demostración, ajustaremos finamente el punto de control preentrenado de tamaño “base” en el conjunto de datos Timit, que es bastante pequeño y contiene solo 5 horas de datos de entrenamiento.
Wav2Vec2 se ajusta finamente utilizando la Clasificación Temporal Conexionista (CTC), que es un algoritmo utilizado para entrenar redes neuronales en problemas de secuencia a secuencia, principalmente en el reconocimiento automático del habla y el reconocimiento de escritura a mano.
Recomiendo leer el artÃculo del blog “Sequence Modeling with CTC (2017)” muy bien escrito por Awni Hannun.
Antes de comenzar, instalemos tanto datasets
como transformers
desde master. Además, necesitamos el paquete soundfile
para cargar archivos de audio y jiwer
para evaluar nuestro modelo ajustado finamente utilizando la métrica de tasa de error de palabras (WER) 1 {}^1 1 .
!pip install datasets>=1.18.3
!pip install transformers==4.11.3
!pip install librosa
!pip install jiwer
A continuación, sugerimos encarecidamente cargar sus puntos de control de entrenamiento directamente en el Hugging Face Hub mientras entrena. El Hub tiene un control de versiones integrado para asegurarse de que no se pierda ningún punto de control del modelo durante el entrenamiento.
Para hacer esto, debe almacenar su token de autenticación del sitio web de Hugging Face (regÃstrese aquà si aún no lo ha hecho)
from huggingface_hub import notebook_login
notebook_login()
Resultado de impresión:
Inicio de sesión exitoso
Su token se ha guardado en /root/.huggingface/token
Autenticado a través de la tienda de credenciales git, pero este no es el asistente definido en su máquina.
Tendrá que volver a autenticarse al enviar a Hugging Face Hub. Ejecute el siguiente comando en su terminal para establecerlo como el valor predeterminado
git config --global credential.helper store
Luego, debe instalar Git-LFS para cargar los puntos de control de su modelo:
!apt install git-lfs
1 {}^1 1 TÃpicamente, Timit se evalúa utilizando la tasa de error de fonemas (PER), pero con mucho, la métrica más común en ASR es la tasa de error de palabras (WER). Para mantener este cuaderno lo más general posible, decidimos evaluar el modelo utilizando WER.
Preparar Datos, Tokenizador, Extractor de CaracterÃsticas
Los modelos de ASR transcriben el habla a texto, lo que significa que necesitamos tanto un extractor de caracterÃsticas que procese la señal de habla al formato de entrada del modelo, por ejemplo, un vector de caracterÃsticas, como un tokenizador que procese el formato de salida del modelo a texto.
En 🤗 Transformers, el modelo Wav2Vec2 está acompañado tanto por un tokenizador, llamado Wav2Vec2CTCTokenizer , como por un extractor de caracterÃsticas, llamado Wav2Vec2FeatureExtractor .
Comencemos creando el tokenizador responsable de decodificar las predicciones del modelo.
Crear Wav2Vec2CTCTokenizer
El punto de control preentrenado de Wav2Vec2 mapea la señal de voz a una secuencia de representaciones de contexto como se muestra en la figura anterior. Un punto de control fine-tuned de Wav2Vec2 necesita mapear esta secuencia de representaciones de contexto a su transcripción correspondiente para que se deba agregar una capa lineal encima del bloque transformador (mostrado en amarillo). Esta capa lineal se utiliza para clasificar cada representación de contexto en una clase de token análoga a cómo, por ejemplo, después del preentrenamiento se agrega una capa lineal encima de las incrustaciones de BERT para una clasificación adicional – ver sección “BERT” de esta publicación de blog.
El tamaño de salida de esta capa corresponde al número de tokens en el vocabulario, que no depende de la tarea de preentrenamiento de Wav2Vec2, sino solo del conjunto de datos etiquetados utilizado para el ajuste fino. Por lo tanto, en el primer paso, echaremos un vistazo a Timit y definiremos un vocabulario basado en las transcripciones del conjunto de datos.
Comencemos cargando el conjunto de datos y echando un vistazo a su estructura.
from datasets import load_dataset, load_metric
timit = load_dataset("timit_asr")
print(timit)
Resultado de impresión:
DatasetDict({
train: Dataset({
features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
num_rows: 4620
})
test: Dataset({
features: ['file', 'audio', 'text', 'phonetic_detail', 'word_detail', 'dialect_region', 'sentence_type', 'speaker_id', 'id'],
num_rows: 1680
})
})
Muchos conjuntos de datos ASR solo proporcionan el texto objetivo, 'text'
para cada archivo de audio 'file'
. Timit proporciona en realidad mucha más información sobre cada archivo de audio, como 'phonetic_detail'
, etc., por eso muchos investigadores eligen evaluar sus modelos en la clasificación de fonemas en lugar del reconocimiento de voz cuando trabajan con Timit. Sin embargo, queremos mantener el cuaderno lo más general posible, por lo que solo consideraremos el texto transcrito para el ajuste fino.
timit = timit.remove_columns(["phonetic_detail", "word_detail", "dialect_region", "id", "sentence_type", "speaker_id"])
Escribamos una función corta para mostrar algunos ejemplos aleatorios del conjunto de datos y ejecutémosla varias veces para tener una idea de las transcripciones.
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML
def show_random_elements(dataset, num_examples=10):
assert num_examples <= len(dataset), "No se pueden seleccionar más elementos de los que hay en el conjunto de datos."
picks = []
for _ in range(num_examples):
pick = random.randint(0, len(dataset)-1)
while pick in picks:
pick = random.randint(0, len(dataset)-1)
picks.append(pick)
df = pd.DataFrame(dataset[picks])
display(HTML(df.to_html()))
show_random_elements(timit["train"].remove_columns(["file", "audio"]))
Resultado de impresión:
¡Genial! Las transcripciones se ven muy limpias y el lenguaje parece corresponder más al texto escrito que al diálogo. Esto tiene sentido teniendo en cuenta que Timit es un corpus de habla leÃda.
Podemos ver que las transcripciones contienen algunos caracteres especiales, como ,.?!;:
. Sin un modelo de lenguaje, es mucho más difÃcil clasificar trozos de voz a esos caracteres especiales porque en realidad no corresponden a una unidad de sonido caracterÃstica. Por ejemplo, la letra "s"
tiene un sonido más o menos claro, mientras que el carácter especial "."
no lo tiene. Además, para comprender el significado de una señal de voz, generalmente no es necesario incluir caracteres especiales en la transcripción.
Además, normalizamos el texto para que solo tenga letras en minúscula.
import re
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"]'
def remove_special_characters(batch):
batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower()
return batch
timit = timit.map(remove_special_characters)
Echemos un vistazo a las transcripciones preprocesadas.
mostrar_elementos_aleatorios(timit["train"].remover_columnas(["file", "audio"]))
Resultado de la impresión:
¡Bien! Esto se ve mejor. Hemos eliminado la mayorÃa de los caracteres especiales de las transcripciones y las hemos normalizado a minúsculas solamente.
En CTC, es común clasificar fragmentos de habla en letras, asà que haremos lo mismo aquÃ. Vamos a extraer todas las letras distintas de los datos de entrenamiento y prueba y construir nuestro vocabulario a partir de este conjunto de letras.
Escribimos una función de mapeo que concatena todas las transcripciones en una sola transcripción larga y luego transforma la cadena en un conjunto de caracteres. Es importante pasar el argumento batched=True
a la función map(...)
para que la función de mapeo tenga acceso a todas las transcripciones a la vez.
def extraer_todos_los_caracteres(lote):
todo_texto = " ".join(lote["texto"])
vocabulario = list(set(todo_texto))
return {"vocabulario": [vocabulario], "todo_texto": [todo_texto]}
vocabs = timit.map(extraer_todos_los_caracteres, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=timit.column_names["train"])
Ahora, creamos la unión de todas las letras distintas en el conjunto de datos de entrenamiento y el conjunto de datos de prueba y convertimos la lista resultante en un diccionario enumerado.
lista_vocabulario = list(set(vocabs["train"]["vocabulario"][0]) | set(vocabs["test"]["vocabulario"][0]))
diccionario_vocabulario = {v: k for k, v in enumerate(lista_vocabulario)}
diccionario_vocabulario
Resultado de la impresión:
{
' ': 21,
"'": 13,
'a': 24,
'b': 17,
'c': 25,
'd': 2,
'e': 9,
'f': 14,
'g': 22,
'h': 8,
'i': 4,
'j': 18,
'k': 5,
'l': 16,
'm': 6,
'n': 7,
'o': 10,
'p': 19,
'q': 3,
'r': 20,
's': 11,
't': 0,
'u': 26,
'v': 27,
'w': 1,
'x': 23,
'y': 15,
'z': 12
}
Genial, vemos que todas las letras del alfabeto aparecen en el conjunto de datos (lo cual no es sorprendente) y también hemos extraÃdo los caracteres especiales " "
y '
. Nota que no excluimos esos caracteres especiales porque:
- El modelo tiene que aprender a predecir cuando una palabra termina, de lo contrario, la predicción del modelo siempre serÃa una secuencia de caracteres, lo que harÃa imposible separar las palabras unas de otras.
- En inglés, necesitamos mantener el carácter
'
para diferenciar entre palabras, por ejemplo,"it's"
y"its"
que tienen significados muy diferentes.
Para dejar más claro que " "
tiene su propia clase de token, le asignamos un carácter más visible |
. Además, también agregamos un token “desconocido” para que el modelo pueda manejar posteriormente caracteres que no se encuentren en el conjunto de entrenamiento de Timit.
diccionario_vocabulario["|"] = diccionario_vocabulario[" "]
del diccionario_vocabulario[" "]
Finalmente, también agregamos un token de relleno que corresponde al “token en blanco” de CTC. El “token en blanco” es un componente fundamental del algoritmo CTC. Para obtener más información, por favor echa un vistazo a la sección de “Alineación” aquÃ.
diccionario_vocabulario["[UNK]"] = len(diccionario_vocabulario)
diccionario_vocabulario["[PAD]"] = len(diccionario_vocabulario)
print(len(diccionario_vocabulario))
Resultado de la impresión:
30
Genial, ahora nuestro vocabulario está completo y consta de 30 tokens, lo que significa que la capa lineal que agregaremos encima del punto de control preentrenado de Wav2Vec2 tendrá una dimensión de salida de 30.
Guardemos ahora el vocabulario como un archivo json.
import json
with open('vocab.json', 'w') as vocab_file:
json.dump(vocab_dict, vocab_file)
En un paso final, usamos el archivo json para instanciar un objeto de la clase Wav2Vec2CTCTokenizer
.
from transformers import Wav2Vec2CTCTokenizer
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
Si alguien desea reutilizar el tokenizer recién creado con el modelo ajustado de este cuaderno, se recomienda encarecidamente cargar el tokenizer
en el 🤗 Hub. Llamemos al repositorio al que subiremos los archivos "wav2vec2-large-xlsr-turkish-demo-colab"
:
repo_name = "wav2vec2-base-timit-demo-colab"
y carga el tokenizer en el 🤗 Hub.
tokenizer.push_to_hub(repo_name)
Genial, puedes ver el repositorio recién creado en https://huggingface.co/<tu-nombre-de-usuario>/wav2vec2-base-timit-demo-colab
Crear Extractor de CaracterÃsticas Wav2Vec2
El habla es una señal continua y para que las computadoras la traten, primero debe ser discretizada, lo que generalmente se llama muestreo. La tasa de muestreo juega un papel importante en que define cuántos puntos de datos de la señal de habla se miden por segundo. Por lo tanto, el muestreo con una tasa de muestreo más alta resulta en una mejor aproximación de la señal de habla real pero también requiere más valores por segundo.
Un punto de control preentrenado espera que sus datos de entrada hayan sido muestreados más o menos desde la misma distribución que los datos en los que se entrenó. Las mismas señales de habla muestreadas a dos tasas diferentes tienen una distribución muy diferente, por ejemplo, duplicar la tasa de muestreo hace que los puntos de datos sean el doble de largos. Por lo tanto, antes de ajustar finamente un punto de control preentrenado de un modelo ASR, es crucial verificar que la tasa de muestreo de los datos que se utilizaron para preentrenar el modelo coincida con la tasa de muestreo del conjunto de datos utilizado para ajustar finamente el modelo.
Wav2Vec2 fue preentrenado en los datos de audio de LibriSpeech y LibriVox, que ambos se muestrearon con 16kHz. Afortunadamente, nuestro conjunto de datos de ajuste fino, Timit, también se muestreó con 16kHz. Si el conjunto de datos de ajuste fino se hubiera muestreado con una tasa menor o mayor que 16kHz, primero habrÃamos tenido que aumentar o disminuir la frecuencia de muestreo de la señal de habla para que coincidiera con la tasa de muestreo de los datos utilizados para el preentrenamiento.
Un objeto extractor de caracterÃsticas Wav2Vec2 requiere los siguientes parámetros para ser instanciado:
feature_size
: Los modelos de habla toman una secuencia de vectores de caracterÃsticas como entrada. Si bien la longitud de esta secuencia obviamente varÃa, el tamaño de la caracterÃstica no deberÃa. En el caso de Wav2Vec2, el tamaño de la caracterÃstica es 1 porque el modelo se entrenó con la señal de habla cruda 2 {}^2 2.sampling_rate
: La tasa de muestreo en la que se entrenó el modelo.padding_value
: Para la inferencia por lotes, las entradas más cortas deben rellenarse con un valor especÃfico.do_normalize
: Si la entrada debe normalizarse a cero media y varianza unitaria o no. Por lo general, los modelos de habla funcionan mejor cuando se normaliza la entrada.return_attention_mask
: Si el modelo debe utilizar unaattention_mask
para la inferencia por lotes. En general, los modelos deben siempre usar laattention_mask
para enmascarar los tokens rellenados. Sin embargo, debido a una elección de diseño muy especÃfica del punto de control “base” deWav2Vec2
, se obtienen mejores resultados cuando no se utiliza unaattention_mask
. Esto no se recomienda para otros modelos de habla. Para obtener más información, se puede consultar este problema. Importante Si desea usar este cuaderno para ajustar finamente large-lv60, este parámetro debe establecerse enTrue
.
from transformers import Wav2Vec2FeatureExtractor
feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
¡Genial, la tuberÃa de extracción de caracterÃsticas de Wav2Vec2 está completamente definida!
Para hacer que el uso de Wav2Vec2 sea lo más amigable posible para el usuario, el extractor de caracterÃsticas y el tokenizador se envuelven en una sola clase Wav2Vec2Processor
para que solo se necesite un objeto modelo
y procesador
.
from transformers import Wav2Vec2Processor
procesador = Wav2Vec2Processor(extractor_de_caracterÃsticas=extractor_de_caracterÃsticas, tokenizador=tokenizador)
Preprocesar datos
Hasta ahora, no hemos examinado los valores reales de la señal de habla, sino solo la transcripción. Además de la oración, nuestros conjuntos de datos incluyen dos columnas más llamadas ruta y audio. La ruta indica la ruta absoluta del archivo de audio. Echemos un vistazo.
print(timit[0]["ruta"])
Salida de impresión:
'/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV'
Wav2Vec2
espera la entrada en el formato de una matriz unidimensional de 16 kHz. Esto significa que el archivo de audio debe cargarse y remuestrearse.
Afortunadamente, datasets hace esto automáticamente llamando a la otra columna audio. Probémoslo.
common_voice_train[0]["audio"]
Salida de impresión:
{'array': array([-2.1362305e-04, 6.1035156e-05, 3.0517578e-05, ...,
-3.0517578e-05, -9.1552734e-05, -6.1035156e-05], dtype=float32),
'ruta': '/root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV',
'tasa_de_muestreo': 16000}
Podemos ver que el archivo de audio se ha cargado automáticamente. Esto se debe a la nueva "caracterÃstica de audio"
introducida en datasets == 4.13.3, que carga y remuestrea archivos de audio sobre la marcha al llamarlos.
La tasa de muestreo se establece en 16kHz, que es lo que Wav2Vec2
espera como entrada.
¡Genial, escuchemos un par de archivos de audio para comprender mejor el conjunto de datos y verificar que el audio se haya cargado correctamente!
import IPython.display as ipd
import numpy as np
import random
rand_int = random.randint(0, len(timit["train"]))
print(timit["train"][rand_int]["text"])
ipd.Audio(data=np.asarray(timit["train"][rand_int]["audio"]["array"]), autoplay=True, rate=16000)
Se puede escuchar que los hablantes cambian junto con su velocidad de habla, acento, etc. En general, las grabaciones suenan relativamente claras, como se esperarÃa de un corpus de habla leÃda.
Hagamos una verificación final de que los datos estén preparados correctamente, imprimiendo la forma de la entrada de habla, su transcripción y la tasa de muestreo correspondiente.
rand_int = random.randint(0, len(timit["train"]))
print("Texto objetivo:", timit["train"][rand_int]["text"])
print("Forma del arreglo de entrada:", np.asarray(timit["train"][rand_int]["audio"]["array"]).shape)
print("Tasa de muestreo:", timit["train"][rand_int]["audio"]["sampling_rate"])
Salida de impresión:
Texto objetivo: she had your dark suit in greasy wash water all year
Forma del arreglo de entrada: (52941,)
Tasa de muestreo: 16000
¡Bien! Todo parece estar bien: los datos son una matriz unidimensional, la tasa de muestreo siempre corresponde a 16kHz y el texto objetivo está normalizado.
Finalmente, podemos procesar el conjunto de datos al formato esperado por el modelo para el entrenamiento. Vamos a utilizar la función map(...)
.
Primero, cargamos y re-muestreamos los datos de audio, simplemente llamando a batch["audio"]
. Segundo, extraemos los input_values
del archivo de audio cargado. En nuestro caso, el Wav2Vec2Processor
solo normaliza los datos. Sin embargo, para otros modelos de habla, este paso puede incluir una extracción de caracterÃsticas más compleja, como la extracción de caracterÃsticas Log-Mel. Tercero, codificamos las transcripciones en identificadores de etiquetas.
Nota: Esta función de mapeo es un buen ejemplo de cómo se debe utilizar la clase Wav2Vec2Processor
. En un contexto “normal”, llamar a processor(...)
se redirige al método de llamada de Wav2Vec2FeatureExtractor
. Sin embargo, al envolver el procesador en el contexto as_target_processor
, el mismo método se redirige al método de llamada de Wav2Vec2CTCTokenizer
. Para obtener más información, consulte la documentación.
def prepare_dataset(batch):
audio = batch["audio"]
# La salida por lotes se "desagrupa" para asegurar que el mapeo sea correcto
batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
with processor.as_target_processor():
batch["labels"] = processor(batch["text"]).input_ids
return batch
Apliquemos la función de preparación de datos a todos los ejemplos.
timit = timit.map(prepare_dataset, remove_columns=timit.column_names["train"], num_proc=4)
Nota: Actualmente, los datasets
utilizan torchaudio
y librosa
para cargar y remuestrear audio. Si desea implementar su propia carga/remuestreo de datos personalizada, simplemente utilice la columna "path"
y no utilice la columna "audio"
.
Entrenamiento y Evaluación
Los datos se procesan para que estemos listos para comenzar a configurar el flujo de entrenamiento. Utilizaremos el Trainer de 🤗, para lo cual básicamente necesitamos hacer lo siguiente:
-
Definir un agrupador de datos. A diferencia de la mayorÃa de los modelos de procesamiento del lenguaje natural (NLP), Wav2Vec2 tiene una longitud de entrada mucho mayor que la longitud de salida. Por ejemplo, una muestra de longitud de entrada 50000 tiene una longitud de salida de no más de 100. Dadas las grandes tamaños de entrada, es mucho más eficiente rellenar los lotes de entrenamiento de forma dinámica, lo que significa que todas las muestras de entrenamiento solo deben rellenarse hasta la muestra más larga en su lote y no hasta la muestra más larga en general. Por lo tanto, para ajustar finamente Wav2Vec2 se requiere un agrupador de datos especial de relleno, que definiremos a continuación
-
Métrica de evaluación. Durante el entrenamiento, el modelo debe evaluarse en la tasa de error de palabras. DeberÃamos definir una función
compute_metrics
en consecuencia -
Cargar un punto de control pre-entrenado. Necesitamos cargar un punto de control pre-entrenado y configurarlo correctamente para el entrenamiento.
-
Definir la configuración de entrenamiento.
Después de haber ajustado finamente el modelo, lo evaluaremos correctamente en los datos de prueba y verificaremos que haya aprendido a transcribir correctamente el habla.
Configurar el Trainer
Comencemos definiendo el agrupador de datos. El código del agrupador de datos se copió de este ejemplo.
Sin entrar en demasiados detalles, a diferencia de los agrupadores de datos comunes, este agrupador de datos trata los input_values
y labels
de manera diferente y, por lo tanto, aplica funciones de relleno separadas en ellos (una vez más, utilizando el administrador de contexto de Wav2Vec2). Esto es necesario porque en el habla, la entrada y la salida son de modalidades diferentes, lo que significa que no deben ser tratadas por la misma función de relleno. De manera análoga a los agrupadores de datos comunes, se rellenan los tokens de relleno en las etiquetas con -100
para que esos tokens no se tengan en cuenta al calcular la pérdida.
import torch
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
@dataclass
class DataCollatorCTCWithPadding:
"""
Agrupador de datos que rellenará dinámicamente las entradas recibidas.
Args:
processor (:class:`~transformers.Wav2Vec2Processor`)
El procesador utilizado para procesar los datos.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, por defecto :obj:`True`):
Seleccione una estrategia para rellenar las secuencias devueltas (según el lado de relleno y el Ãndice de relleno del modelo)
entre:
* :obj:`True` o :obj:`'longest'`: Rellena hasta la secuencia más larga en el lote (o sin relleno si solo se proporciona una única secuencia).
* :obj:`'max_length'`: Rellena hasta una longitud máxima especificada con el argumento :obj:`max_length` o hasta la longitud de entrada máxima aceptable para el modelo si no se proporciona ese argumento.
* :obj:`False` o :obj:`'do_not_pad'` (predeterminado): Sin relleno (es decir, puede generar un lote con secuencias de longitudes diferentes).
max_length (:obj:`int`, `opcional`):
Longitud máxima de los ``input_values`` de la lista devuelta y opcionalmente longitud de relleno (ver arriba).
max_length_labels (:obj:`int`, `opcional`):
Longitud máxima de las listas ``labels`` devueltas y opcionalmente longitud de relleno (ver arriba).
pad_to_multiple_of (:obj:`int`, `opcional`):
Si se establece, rellenará la secuencia a un múltiplo del valor proporcionado.
Esto es especialmente útil para habilitar el uso de Tensor Cores en hardware NVIDIA con capacidad de cálculo >=
7.5 (Volta).
"""
processor: Wav2Vec2Processor
padding: Union[bool, str] = True
max_length: Optional[int] = None
max_length_labels: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
pad_to_multiple_of_labels: Optional[int] = None
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# dividir las entradas y las etiquetas ya que deben tener longitudes diferentes y necesitan
# diferentes métodos de relleno
input_features = [{"input_values": feature["input_values"]} for feature in features]
label_features = [{"input_ids": feature["labels"]} for feature in features]
batch = self.processor.pad(
input_features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors="pt",
)
with self.processor.as_target_processor():
labels_batch = self.processor.pad(
label_features,
padding=self.padding,
max_length=self.max_length_labels,
pad_to_multiple_of=self.pad_to_multiple_of_labels,
return_tensors="pt",
)
# reemplazar el relleno con -100 para ignorar la pérdida correctamente
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
batch["labels"] = labels
return batch
Inicialicemos el colector de datos.
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
A continuación, se define la métrica de evaluación. Como se mencionó anteriormente, la métrica predominante en ASR es la tasa de error de palabras (WER), por lo tanto, también la utilizaremos en este cuaderno.
wer_metric = load_metric("wer")
El modelo devolverá una secuencia de vectores de logitos:
y 1 , … , y m \mathbf{y}_1, \ldots, \mathbf{y}_m y 1 ​ , … , y m ​ ,
con y 1 = f θ ( x 1 , … , x n ) [ 0 ] \mathbf{y}_1 = f_{\theta}(x_1, \ldots, x_n)[0] y 1 ​ = f θ ​ ( x 1 ​ , … , x n ​ ) [ 0 ] y n > > m n > > m n > > m .
Un vector de logitos y 1 \mathbf{y}_1 y 1 ​ contiene las probabilidades logarÃtmicas para cada palabra en el vocabulario que definimos anteriormente, por lo tanto len ( y i ) = \text{len}(\mathbf{y}_i) = len ( y i ​ ) = config.vocab_size
. Estamos interesados en la predicción más probable del modelo, por lo que tomamos argmax(...)
de los logitos. Además, transformamos las etiquetas codificadas de vuelta a la cadena original reemplazando -100
con el pad_token_id
y decodificando los ids asegurándonos de que los tokens consecutivos no se agrupen en el mismo token en el estilo de CTC 1 {}^1 1 .
def compute_metrics(pred):
pred_logits = pred.predictions
pred_ids = np.argmax(pred_logits, axis=-1)
pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
pred_str = processor.batch_decode(pred_ids)
# no queremos agrupar tokens al calcular las métricas
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
wer = wer_metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
Ahora, podemos cargar el punto de control pre-entrenado de Wav2Vec2
. El pad_token_id
del tokenizador debe definirse como el pad_token_id
del modelo o en el caso de Wav2Vec2ForCTC
también el token en blanco de CTC 2 {}^2 2 . Para ahorrar memoria de GPU, habilitamos la comprobación de gradientes de checkpoint de PyTorch y también establecemos la reducción de pérdida en “mean”.
from transformers import Wav2Vec2ForCTC
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-base",
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id,
)
Imprimir salida:
Algunos pesos de Wav2Vec2ForCTC no se inicializaron desde el punto de control del modelo en facebook/wav2vec2-base y se inicializan de nuevo: ['lm_head.weight', 'lm_head.bias']
Probablemente deberÃas ENTRENAR este modelo en una tarea secundaria para poder usarlo para predicciones e inferencia.
El primer componente de Wav2Vec2 consiste en una pila de capas CNN que se utilizan para extraer caracterÃsticas acústicamente significativas pero contextualmente independientes de la señal de voz en bruto. Esta parte del modelo ya ha sido suficientemente entrenada durante el preentrenamiento y, como se indica en el artÃculo, no es necesario ajustarla más. Por lo tanto, podemos establecer requires_grad
en False
para todos los parámetros de la parte de extracción de caracterÃsticas.
model.freeze_feature_extractor()
En un último paso, definimos todos los parámetros relacionados con el entrenamiento. Para dar más explicación sobre algunos de los parámetros:
group_by_length
hace que el entrenamiento sea más eficiente agrupando muestras de entrenamiento de longitud de entrada similar en un lote. Esto puede acelerar significativamente el tiempo de entrenamiento al reducir en gran medida el número total de tokens de relleno inútiles que se pasan a través del modelo.learning_rate
yweight_decay
se ajustaron de forma heurÃstica hasta que el ajuste fino se volvió estable. Tenga en cuenta que estos parámetros dependen fuertemente del conjunto de datos de Timit y pueden ser subóptimos para otros conjuntos de datos de voz.
Para obtener más explicaciones sobre otros parámetros, se puede consultar la documentación.
Durante el entrenamiento, se cargará de forma asÃncrona un punto de control en el hub cada 400 pasos de entrenamiento. Esto le permite también jugar con el widget de demostración incluso mientras su modelo aún se está entrenando.
Nota: Si no se desea cargar los puntos de control del modelo en el hub, simplemente configure push_to_hub=False
.
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir=nombre_repo,
group_by_length=True,
per_device_train_batch_size=32,
evaluation_strategy="steps",
num_train_epochs=30,
fp16=True,
gradient_checkpointing=True,
save_steps=500,
eval_steps=500,
logging_steps=500,
learning_rate=1e-4,
weight_decay=0.005,
warmup_steps=1000,
save_total_limit=2,
)
Ahora, todas las instancias se pueden pasar al entrenador y estamos listos para comenzar el entrenamiento.
from transformers import Trainer
trainer = Trainer(
model=model,
data_collator=data_collator,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=timit_prepared["train"],
eval_dataset=timit_prepared["test"],
tokenizer=processor.feature_extractor,
)
1 {}^1 1 Para permitir que los modelos sean independientes de la velocidad del hablante, en CTC, los tokens consecutivos que son idénticos se agrupan simplemente como un solo token. Sin embargo, las etiquetas codificadas no deben agruparse al decodificar, ya que no corresponden a los tokens predichos del modelo, por eso se debe pasar el parámetro group_tokens=False
. Si no pasáramos este parámetro, una palabra como "hello"
se codificarÃa incorrectamente y se decodificarÃa como "helo"
.
2 {}^2 2 El token en blanco permite que el modelo prediga una palabra, como "hello"
, al obligarlo a insertar el token en blanco entre las dos “l”. Una predicción conforme a CTC de "hello"
de nuestro modelo serÃa [PAD] [PAD] "h" "e" "e" "l" "l" [PAD] "l" "o" "o" [PAD]
.
Entrenamiento
El entrenamiento tomará entre 90 y 180 minutos dependiendo de la GPU asignada al Google Colab adjunto a este cuaderno. Si bien el modelo entrenado produce resultados satisfactorios en los datos de prueba de Timit, no es de ninguna manera un modelo optimizado y ajustado al máximo. El propósito de este cuaderno es demostrar cómo se pueden ajustar los puntos de control de Wav2Vec2 base, grande y grande-lv60 en cualquier conjunto de datos en inglés.
En caso de que desee usar este Google Colab para ajustar su modelo, asegúrese de que su entrenamiento no se detenga debido a la inactividad. Un truco simple para evitar esto es pegar el siguiente código en la consola de esta pestaña (haga clic derecho -> inspeccionar -> pestaña Consola e inserte el código).
function ConnectButton(){
console.log("Connect pushed");
document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click()
}
setInterval(ConnectButton,60000);
trainer.train()
Dependiendo de su GPU, es posible que vea un error de "out-of-memory"
aquÃ. En este caso, lo mejor será reducir per_device_train_batch_size
a 16 o incluso menos y eventualmente hacer uso de gradient_accumulation
.
Resultado de la impresión:
El WER final deberÃa ser inferior a 0.3, lo cual es razonable dado que las tasas de error de fonemas (PER) de última generación son ligeramente inferiores a 0.1 (consulte la tabla de clasificación) y que el WER suele ser peor que el PER.
Ahora puede cargar el resultado del entrenamiento en el Hub, simplemente ejecute esta instrucción:
trainer.push_to_hub()
Ahora puede compartir este modelo con todos sus amigos, familiares, mascotas favoritas: todos pueden cargarlo con el identificador “su-nombre-de-usuario/el-nombre-que-eligió”, por ejemplo:
from transformers import AutoModelForCTC, Wav2Vec2Processor
model = AutoModelForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-timit-demo-colab")
processor = Wav2Vec2Processor.from_pretrained("patrickvonplaten/wav2vec2-base-timit-demo-colab")
Evaluación
En la parte final, evaluamos nuestro modelo afinado en el conjunto de pruebas y jugamos un poco con él.
Carguemos el procesador
y el modelo
.
procesador = Wav2Vec2Processor.from_pretrained(repo_name)
modelo = Wav2Vec2ForCTC.from_pretrained(repo_name)
Ahora, haremos uso de la función map(...)
para predecir la transcripción de cada muestra de prueba y guardar la predicción en el propio conjunto de datos. Llamaremos al diccionario resultante "resultados"
.
Nota: evaluamos el conjunto de datos de prueba con batch_size=1
a propósito debido a este problema. Dado que las entradas acolchadas no producen la misma salida exacta que las entradas no acolchadas, se puede lograr un mejor WER no acolchando la entrada en absoluto.
def mapear_a_resultado(lote):
with torch.no_grad():
valores_entrada = torch.tensor(lote["input_values"], device="cuda").unsqueeze(0)
logits = modelo(valores_entrada).logits
ids_pred = torch.argmax(logits, dim=-1)
lote["pred_str"] = procesador.batch_decode(ids_pred)[0]
lote["texto"] = procesador.decode(lote["labels"], group_tokens=False)
return lote
resultados = timit["test"].map(mapear_a_resultado, remove_columns=timit["test"].column_names)
Calculemos ahora el WER global.
print("WER de prueba: {:.3f}".format(wer_metric.compute(predictions=resultados["pred_str"], references=resultados["texto"])))
Resultado de impresión:
WER de prueba: 0.221
¡22.1% WER – no está mal! Nuestro modelo de demostración probablemente habrÃa pasado a la lista oficial.
Echemos un vistazo a algunas predicciones para ver qué errores comete el modelo.
Resultado de impresión:
mostrar_elementos_aleatorios(resultados.remove_columns(["speech", "sampling_rate"]))
Queda claro que las transcripciones predichas son acústicamente muy similares a las transcripciones objetivo, pero a menudo contienen errores ortográficos o gramaticales. Esto no deberÃa ser muy sorprendente, dado que nos basamos exclusivamente en Wav2Vec2 sin hacer uso de un modelo de lenguaje.
Finalmente, para comprender mejor cómo funciona CTC, vale la pena echar un vistazo más profundo a la salida exacta del modelo. Ejecutemos la primera muestra de prueba a través del modelo, tomemos los IDs predichos y convirtámoslos en sus tokens correspondientes.
modelo.to("cuda")
with torch.no_grad():
logits = modelo(torch.tensor(timit["test"][:1]["input_values"], device="cuda")).logits
ids_pred = torch.argmax(logits, dim=-1)
# convertir IDs a tokens
" ".join(procesador.tokenizer.convert_ids_to_tokens(ids_pred[0].tolist()))
Resultado de impresión:
[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] t t h e e | | b b [PAD] u u n n n g g [PAD] a [PAD] [PAD] l l [PAD] o o o [PAD] | w w a a [PAD] s s | | [PAD] [PAD] p l l e e [PAD] [PAD] s s e n n t t t [PAD] l l y y | | | s s [PAD] i i [PAD] t t t [PAD] u u u u [PAD] [PAD] [PAD] a a [PAD] t t e e e d d d | n n e e a a a r | | t h h e | | s s h h h [PAD] o o o [PAD] o o r r [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
La salida deberÃa aclarar un poco cómo funciona CTC en la práctica. El modelo es en cierta medida invariante a la velocidad de habla, ya que ha aprendido a repetir el mismo token en caso de que el fragmento de habla a clasificar aún corresponda al mismo token. Esto hace que CTC sea un algoritmo muy poderoso para el reconocimiento de voz, ya que la transcripción del archivo de voz suele ser muy independiente de su longitud.
Una vez más, aconsejo al lector que eche un vistazo a esta entrada de blog muy interesante para comprender mejor CTC.