Ajuste fino de Whisper para ASR multilingüe con 🤗 Transformers.
Ajuste fino de Whisper para ASR multilingüe con Transformers.
En este blog, presentamos una guía paso a paso sobre cómo ajustar finamente Whisper para cualquier conjunto de datos ASR multilingüe utilizando Hugging Face 🤗 Transformers. Este blog proporciona explicaciones detalladas del modelo Whisper, el conjunto de datos Common Voice y la teoría detrás del ajuste fino, con celdas de código acompañantes para ejecutar los pasos de preparación de datos y ajuste fino. Para una versión más simplificada del cuaderno con menos explicaciones pero todo el código, consulte el Google Colab adjunto.
Tabla de Contenidos
- Introducción
- Ajuste fino de Whisper en Google Colab
- Preparar el Entorno
- Cargar el Conjunto de Datos
- Preparar el Extractor de Características, Tokenizador y Datos
- Entrenamiento y Evaluación
- Construir una Demo
- Conclusiones
Introducción
Whisper es un modelo pre-entrenado para reconocimiento automático del habla (ASR) publicado en septiembre de 2022 por los autores Alec Radford et al. de OpenAI. A diferencia de muchos de sus predecesores, como Wav2Vec 2.0, que se pre-entrenan en datos de audio sin etiquetar, Whisper se pre-entrena en una cantidad vasta de datos de audio-transcripción etiquetados, precisamente 680,000 horas. Esto es una magnitud de datos más que los datos de audio sin etiquetar utilizados para entrenar Wav2Vec 2.0 (60,000 horas). Además, 117,000 horas de estos datos de pre-entrenamiento son datos ASR multilingües. Esto resulta en puntos de control que se pueden aplicar a más de 96 idiomas, muchos de los cuales se consideran de recursos limitados.
Esta cantidad de datos etiquetados permite que Whisper se pre-entrene directamente en la tarea supervisada de reconocimiento del habla, aprendiendo una asignación de habla a texto a partir de los datos de pre-entrenamiento de audio-transcripción etiquetados 1 {}^1 1 . Como consecuencia, Whisper requiere poco ajuste fino adicional para producir un modelo ASR de alto rendimiento. Esto contrasta con Wav2Vec 2.0, que se pre-entrena en la tarea no supervisada de predicción enmascarada. Aquí, el modelo se entrena para aprender una asignación intermedia de habla a estados ocultos a partir de datos de audio sin etiquetar solamente. Si bien el pre-entrenamiento no supervisado produce representaciones de habla de alta calidad, no aprende una asignación de habla a texto. Esta asignación solo se aprende durante el ajuste fino, lo que requiere más ajuste fino para obtener un rendimiento competitivo.
Cuando se escalan a 680,000 horas de datos de pre-entrenamiento etiquetados, los modelos Whisper demuestran una fuerte capacidad para generalizar en muchos conjuntos de datos y dominios. Los puntos de control pre-entrenados logran resultados competitivos en comparación con los sistemas ASR de última generación, con una tasa de error de palabras (WER) cercana al 3% en el subconjunto de prueba limpio de LibriSpeech ASR y un nuevo estado de arte en TED-LIUM con un 4.7% WER ( véase la Tabla 8 del artículo Whisper ). El conocimiento extenso de ASR multilingüe adquirido por Whisper durante el pre-entrenamiento se puede aprovechar para otros idiomas de recursos limitados; mediante el ajuste fino, los puntos de control pre-entrenados se pueden adaptar a conjuntos de datos y idiomas específicos para mejorar aún más estos resultados.
- Entrenamiento de difusión estable con Dreambooth utilizando difusores
- Generando texto a nivel humano con búsqueda contrastiva en Transfor...
- Director de Perspicacias de Aprendizaje Automático [Parte 4]
Whisper es un modelo codificador-decodificador basado en Transformer, también conocido como un modelo secuencia-a-secuencia. Mapea una secuencia de características de espectrograma de audio a una secuencia de tokens de texto. Primero, las entradas de audio sin procesar se convierten en un espectrograma de log-Mel mediante el extractor de características. Luego, el codificador Transformer codifica el espectrograma para formar una secuencia de estados ocultos del codificador. Finalmente, el decodificador predice autoregresivamente los tokens de texto, condicionados tanto por los tokens previos como por los estados ocultos del codificador. La Figura 1 resume el modelo Whisper.
En un modelo de secuencia a secuencia, el codificador transforma las entradas de audio en un conjunto de representaciones de estado oculto, extrayendo características importantes del habla pronunciada. El decodificador desempeña el papel de un modelo de lenguaje, procesando las representaciones de estado oculto y generando las correspondientes transcripciones de texto. La incorporación de un modelo de lenguaje internamente en la arquitectura del sistema se denomina fusión profunda. Esto contrasta con la fusión superficial, donde un modelo de lenguaje se combina externamente con un codificador, como en CTC + n n n-grama (cf. Estimación de modelo de lenguaje interno). Con la fusión profunda, todo el sistema se puede entrenar de principio a fin con los mismos datos de entrenamiento y función de pérdida, lo que brinda una mayor flexibilidad y un rendimiento generalmente superior (cf. Referencia ESB).
Whisper se preentrena y ajusta utilizando la función objetivo de entropía cruzada, una función objetivo estándar para entrenar sistemas de secuencia a secuencia en tareas de clasificación. Aquí, el sistema se entrena para clasificar correctamente el token de texto objetivo de un vocabulario predefinido de tokens de texto.
Los puntos de control de Whisper vienen en cinco configuraciones de diferentes tamaños de modelo. Los cuatro más pequeños se entrenan solo con datos en inglés o multilingües. El punto de control más grande es solo multilingüe. Los nueve puntos de control preentrenados están disponibles en el Hugging Face Hub. Los puntos de control se resumen en la siguiente tabla con enlaces a los modelos en el Hub:
Para fines de demostración, ajustaremos finamente la versión multilingüe del punto de control small
con 244M de parámetros (~= 1GB). En cuanto a nuestros datos, entrenaremos y evaluaremos nuestro sistema en un idioma de recursos limitados tomado del conjunto de datos Common Voice. Mostraremos que con tan solo 8 horas de datos de ajuste fino, podemos lograr un buen rendimiento en este idioma.
1 {}^1 1 El nombre Whisper proviene del acrónimo “WSPSR”, que significa “Preentrenamiento supervisado a escala web para reconocimiento de voz”.
Ajuste fino de Whisper en Google Colab
Preparar el entorno
Utilizaremos varios paquetes populares de Python para ajustar finamente el modelo Whisper. Usaremos datasets
para descargar y preparar nuestros datos de entrenamiento y transformers
para cargar y entrenar nuestro modelo Whisper. También necesitaremos el paquete soundfile
para preprocesar archivos de audio, evaluate
y jiwer
para evaluar el rendimiento de nuestro modelo. Finalmente, usaremos gradio
para construir una demostración llamativa de nuestro modelo ajustado finamente.
!pip install datasets>=2.6.1
!pip install git+https://github.com/huggingface/transformers
!pip install librosa
!pip install evaluate>=0.30
!pip install jiwer
!pip install gradio
Le recomendamos encarecidamente que cargue los puntos de control del modelo directamente en el Hugging Face Hub durante el entrenamiento. El Hub proporciona:
- Control de versiones integrado: puede estar seguro de que no se pierde ningún punto de control del modelo durante el entrenamiento.
- Registros de Tensorboard: realiza un seguimiento de las métricas importantes a lo largo del entrenamiento.
- Model cards: documenta qué hace un modelo y sus casos de uso previstos.
- Comunidad: ¡una forma fácil de compartir y colaborar con la comunidad!
Enlazar el cuaderno al Hub es sencillo: simplemente ingrese su token de autenticación del Hub cuando se le solicite. Encuentre su token de autenticación del Hub aquí:
from huggingface_hub import notebook_login
notebook_login()
Salida impresa:
Login successful
Your token has been saved to /root/.huggingface/token
Cargar conjunto de datos
Common Voice es una serie de conjuntos de datos de crowdsourcing donde los hablantes graban texto de Wikipedia en varios idiomas. Utilizaremos la última edición del conjunto de datos Common Voice (versión 11). En cuanto a nuestro idioma, ajustaremos finamente nuestro modelo en hindi, un idioma indoario hablado en el norte, centro, este y oeste de India. Common Voice 11.0 contiene aproximadamente 12 horas de datos de hindi etiquetados, 4 de los cuales son datos de prueba excluidos.
Vamos al Hub y veamos la página del conjunto de datos de Common Voice: mozilla-foundation/common_voice_11_0.
La primera vez que veamos esta página, se nos pedirá que aceptemos los términos de uso. Después de eso, se nos dará acceso completo al conjunto de datos.
Una vez que hayamos proporcionado autenticación para usar el conjunto de datos, se nos presentará una vista previa del conjunto de datos. La vista previa del conjunto de datos nos muestra los primeros 100 ejemplos del conjunto de datos. Además, está cargado con muestras de audio listas para que las escuchemos en tiempo real. Podemos seleccionar el subconjunto en hindi de Common Voice estableciendo el subconjunto en hi
utilizando el menú desplegable (hi
es el código identificador de idioma para hindi):
Si presionamos el botón de reproducción en el primer ejemplo, podemos escuchar el audio y ver el texto correspondiente. Recorre las muestras de los conjuntos de entrenamiento y prueba para tener una mejor idea de los datos de audio y texto con los que estamos trabajando. Puedes notar por la entonación y el estilo que las grabaciones se toman del habla narrada. También es probable que notes la gran variación en los hablantes y la calidad de las grabaciones, una característica común de los datos generados por la multitud.
Usando 🤗 Datasets, descargar y preparar datos es extremadamente sencillo. Podemos descargar y preparar las divisiones de Common Voice en solo una línea de código. Dado que el hindi tiene muy pocos recursos, combinaremos las divisiones de train
y validation
para obtener aproximadamente 8 horas de datos de entrenamiento. Utilizaremos las 4 horas de datos de test
como nuestro conjunto de prueba retenido:
from datasets import load_dataset, DatasetDict
common_voice = DatasetDict()
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="train+validation", use_auth_token=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test", use_auth_token=True)
print(common_voice)
Resultado de impresión:
DatasetDict({
train: Dataset({
features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
num_rows: 6540
})
test: Dataset({
features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
num_rows: 2894
})
})
La mayoría de los conjuntos de datos ASR solo proporcionan muestras de audio de entrada (audio
) y el texto transcrito correspondiente (sentence
). Common Voice contiene información adicional de metadatos, como accent
y locale
, que podemos ignorar para ASR. Manteniendo el cuaderno lo más general posible, solo consideramos el audio de entrada y el texto transcrito para el ajuste fino, descartando la información adicional de los metadatos:
common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
Common Voice es solo uno de los conjuntos de datos ASR multilingües que podemos descargar desde el Hub; ¡hay muchos más disponibles para nosotros! Para ver la variedad de conjuntos de datos disponibles para el reconocimiento del habla, sigue el enlace: Conjuntos de datos ASR en el Hub .
Preparar el extractor de características, el tokenizador y los datos
El pipeline de ASR se puede descomponer en tres componentes:
- Un extractor de características que preprocesa las entradas de audio sin procesar
- El modelo que realiza el mapeo de secuencia a secuencia
- Un tokenizador que postprocesa las salidas del modelo en formato de texto
En 🤗 Transformers, el modelo Whisper tiene un extractor de características y un tokenizador asociados, llamados WhisperFeatureExtractor y WhisperTokenizer respectivamente.
¡Pasaremos por los detalles del extractor de características y el tokenizador uno por uno!
Cargar WhisperFeatureExtractor
El habla se representa mediante una matriz unidimensional que varía con el tiempo. El valor de la matriz en cualquier momento dado es la amplitud de la señal en ese punto. A partir de la información de amplitud solamente, podemos reconstruir el espectro de frecuencia del audio y recuperar todas las características acústicas.
Dado que el habla es continua, contiene un número infinito de valores de amplitud. Esto plantea problemas para los dispositivos informáticos que esperan matrices finitas. Por lo tanto, discretizamos nuestra señal de voz muestreando valores de nuestra señal en intervalos de tiempo fijos. El intervalo con el que muestreamos nuestro audio se conoce como la frecuencia de muestreo y generalmente se mide en muestras/segundo o Hertz (Hz). Muestrear con una frecuencia de muestreo más alta da como resultado una mejor aproximación de la señal de voz continua, pero también requiere almacenar más valores por segundo.
Es crucial que coincidamos la tasa de muestreo de nuestras entradas de audio con la tasa de muestreo esperada por nuestro modelo, ya que las señales de audio con diferentes tasas de muestreo tienen distribuciones muy diferentes. Las muestras de audio solo deben procesarse con la tasa de muestreo correcta. ¡No hacerlo puede llevar a resultados inesperados! Por ejemplo, tomar una muestra de audio con una tasa de muestreo de 16 kHz y escucharla con una tasa de muestreo de 8 kHz hará que el audio suene como si estuviera a la mitad de velocidad. De la misma manera, pasar audio con la tasa de muestreo incorrecta puede afectar a un modelo ASR que espera una tasa de muestreo y recibe otra. El extractor de características Whisper espera entradas de audio con una tasa de muestreo de 16 kHz, por lo que debemos igualar nuestras entradas a este valor. ¡No queremos entrenar inadvertidamente un sistema ASR en habla a cámara lenta!
El extractor de características Whisper realiza dos operaciones. Primero, rellena/trunca un lote de muestras de audio para que todas las muestras tengan una longitud de entrada de 30 segundos. Las muestras más cortas de 30 segundos se rellenan con ceros al final de la secuencia (ceros en una señal de audio que corresponden a ninguna señal o silencio). Las muestras más largas de 30 segundos se truncan a 30 segundos. Dado que todos los elementos del lote se rellenan/truncan a una longitud máxima en el espacio de entrada, no requerimos una máscara de atención al enviar las entradas de audio al modelo Whisper. Whisper es único en este sentido: con la mayoría de los modelos de audio, puede esperar proporcionar una máscara de atención que detalle dónde se han rellenado las secuencias y, por lo tanto, dónde se deben ignorar en el mecanismo de autoatención. Whisper está entrenado para funcionar sin una máscara de atención e inferir directamente de las señales de habla dónde ignorar las entradas.
La segunda operación que realiza el extractor de características Whisper es convertir los arreglos de audio rellenos a espectrogramas de log-Mel. Estos espectrogramas son una representación visual de las frecuencias de una señal, como una transformada de Fourier. Se muestra un ejemplo de espectrograma en la Figura 2. A lo largo del eje y están los canales de Mel, que corresponden a ciertos intervalos de frecuencia. A lo largo del eje x está el tiempo. El color de cada píxel corresponde a la intensidad logarítmica de ese intervalo de frecuencia en un momento dado. El espectrograma de log-Mel es la forma de entrada esperada por el modelo Whisper.
Los canales de Mel (intervalos de frecuencia) son estándar en el procesamiento del habla y se eligen para aproximar el rango auditivo humano. Todo lo que necesitamos saber para el ajuste fino de Whisper es que el espectrograma es una representación visual de las frecuencias en la señal de habla. Para más detalles sobre los canales de Mel, consulte “cepstrum de frecuencia de Mel”.

Afortunadamente para nosotros, el extractor de características Whisper de 🤗 Transformers realiza tanto el relleno como la conversión a espectrograma en solo una línea de código. ¡Sigamos adelante y carguemos el extractor de características desde el punto de control preentrenado para tenerlo listo para nuestros datos de audio:
from transformers import WhisperFeatureExtractor
extractor_de_caracteristicas = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
Cargar WhisperTokenizer
Ahora veamos cómo cargar un tokenizador de Whisper. El modelo Whisper devuelve tokens de texto que indican el índice del texto predicho entre los elementos del diccionario de vocabulario. El tokenizador asigna una secuencia de tokens de texto a la cadena de texto real (por ejemplo, [1169, 3797, 3332] -> “el gato se sentó”).
Tradicionalmente, al usar modelos solo de codificador para ASR, decodificamos usando la Clasificación Temporal Conectiva (CTC). Aquí se nos exige entrenar un tokenizador CTC para cada conjunto de datos que usemos. Una de las ventajas de usar una arquitectura codificador-decodificador es que podemos aprovechar directamente el tokenizador del modelo preentrenado.
El tokenizador Whisper está preentrenado en las transcripciones de los 96 idiomas de preentrenamiento. En consecuencia, tiene un byte-pair extenso que es apropiado para casi todas las aplicaciones multilingües de ASR. Para hindi, podemos cargar el tokenizador y usarlo para el ajuste fino sin ninguna modificación adicional. Simplemente tenemos que especificar el idioma objetivo y la tarea. Estos argumentos informan al tokenizador que agregue los tokens de idioma y tarea al inicio de las secuencias de etiquetas codificadas:
from transformers import WhisperTokenizer
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
Podemos verificar que el tokenizador codifica correctamente los caracteres en hindi mediante la codificación y decodificación de la primera muestra del conjunto de datos de Common Voice. Al codificar las transcripciones, el tokenizador agrega ‘tokens especiales’ al inicio y al final de la secuencia, incluyendo los tokens de inicio/fin de transcripción, el token de idioma y los tokens de tarea (como se especifica en los argumentos del paso anterior). Al decodificar los identificadores de etiquetas, tenemos la opción de ‘omitir’ estos tokens especiales, lo que nos permite devolver una cadena en la forma de entrada original:
input_str = common_voice["train"][0]["sentence"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)
print(f"Entrada: {input_str}")
print(f"Decodificado con especiales: {decoded_with_special}")
print(f"Decodificado sin especiales: {decoded_str}")
print(f"Son iguales: {input_str == decoded_str}")
Resultado de la impresión:
Entrada: खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई
Decodificado con especiales: <|startoftranscript|><|hi|><|transcribe|><|notimestamps|>खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई<|endoftext|>
Decodificado sin especiales: खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई
Son iguales: True
Combinar para crear un WhisperProcessor
Para simplificar el uso del extractor de características y el tokenizador, podemos envolver ambos en una única clase WhisperProcessor
. Este objeto procesador hereda de WhisperFeatureExtractor
y WhisperProcessor
y se puede utilizar en las entradas de audio y las predicciones del modelo según sea necesario. Al hacerlo, solo necesitamos realizar un seguimiento de dos objetos durante el entrenamiento: el procesador
y el modelo
:
from transformers import WhisperProcessor
procesador = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
Preparar los datos
Imprimamos el primer ejemplo del conjunto de datos de Common Voice para ver en qué forma está la información:
print(common_voice["train"][0])
Resultado de la impresión:
{'audio': {'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/607848c7e74a89a3b5225c0fa5ffb9470e39b7f11112db614962076a847f3abf/cv-corpus-11.0-2022-09-21/hi/clips/common_voice_hi_25998259.mp3',
'array': array([0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 9.6724887e-07,
1.5334779e-06, 1.0415988e-06], dtype=float32),
'sampling_rate': 48000},
'sentence': 'खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई'}
Podemos ver que tenemos una matriz de audio de entrada unidimensional y la correspondiente transcripción objetivo. Hemos hablado mucho sobre la importancia de la frecuencia de muestreo y el hecho de que necesitamos que coincida con la frecuencia de muestreo de nuestro modelo Whisper (16 kHz). Dado que nuestra entrada de audio se muestrea a 48 kHz, debemos reducir su frecuencia de muestreo a 16 kHz antes de pasarla al extractor de características Whisper.
Vamos a establecer las entradas de audio en la tasa de muestreo correcta utilizando el método cast_column
del conjunto de datos. Esta operación no cambia el audio en su lugar, sino que indica a datasets
que vuelva a muestrear las muestras de audio sobre la marcha la primera vez que se carguen:
from datasets import Audio
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
Al volver a cargar la primera muestra de audio en el conjunto de datos de Common Voice, se volverá a muestrear a la tasa de muestreo deseada:
print(common_voice["train"][0])
Salida de impresión:
{'audio': {'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/607848c7e74a89a3b5225c0fa5ffb9470e39b7f11112db614962076a847f3abf/cv-corpus-11.0-2022-09-21/hi/clips/common_voice_hi_25998259.mp3',
'array': array([ 0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
-3.4206650e-07, 3.2979898e-07, 1.0042874e-06], dtype=float32),
'sampling_rate': 16000},
'sentence': 'खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई'}
¡Genial! Podemos ver que la tasa de muestreo se ha reducido a 16 kHz. Los valores de la matriz también son diferentes, ya que ahora solo tenemos aproximadamente un valor de amplitud por cada tres que teníamos antes.
Ahora podemos escribir una función para preparar nuestros datos listos para el modelo:
- Cargamos y volvemos a muestrear los datos de audio llamando a
batch["audio"]
. Como se explicó anteriormente, 🤗 Datasets realiza cualquier operación de re-muestreo necesaria sobre la marcha. - Utilizamos el extractor de características para calcular las características de entrada del espectrograma log-Mel a partir de nuestra matriz de audio unidimensional.
- Codificamos las transcripciones en identificadores de etiquetas mediante el uso del tokenizador.
def prepare_dataset(batch):
# cargar y volver a muestrear los datos de audio de 48 a 16 kHz
audio = batch["audio"]
# calcular características de entrada de log-Mel a partir de la matriz de audio de entrada
batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
# codificar el texto de destino en identificadores de etiquetas
batch["labels"] = tokenizer(batch["sentence"]).input_ids
return batch
Podemos aplicar la función de preparación de datos a todos nuestros ejemplos de entrenamiento utilizando el método .map
del conjunto de datos:
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=4)
¡Listo! ¡Ahora tenemos nuestros datos completamente preparados para el entrenamiento! Continuemos y veamos cómo podemos usar estos datos para ajustar finamente Whisper.
Nota: Actualmente, datasets
utiliza tanto torchaudio
como librosa
para cargar y re-muestrear audio. Si desea implementar su propia carga/re-muestreo de datos personalizados, puede utilizar la columna "path"
para obtener la ruta del archivo de audio y omitir la columna "audio"
.
Entrenamiento y Evaluación
Ahora que hemos preparado nuestros datos, estamos listos para sumergirnos en el pipeline de entrenamiento. El 🤗 Trainer hará gran parte del trabajo pesado por nosotros. Todo lo que tenemos que hacer es:
-
Definir un colector de datos: el colector de datos toma nuestros datos preprocesados y prepara tensores de PyTorch listos para el modelo.
-
Métricas de evaluación: durante la evaluación, queremos evaluar el modelo utilizando la métrica de tasa de error de palabras (WER). Necesitamos definir una función
compute_metrics
que maneje este cálculo. -
Cargar un punto de control pre-entrenado: necesitamos cargar un punto de control pre-entrenado y configurarlo correctamente para el entrenamiento.
-
Definir los argumentos de entrenamiento: estos serán utilizados por el 🤗 Trainer para construir el programa de entrenamiento.
Una vez que hayamos afinado el modelo, lo evaluaremos en los datos de prueba para verificar que lo hemos entrenado correctamente para transcribir el habla en hindi.
Definir un Collator de Datos
El collator de datos para un modelo de habla de secuencia a secuencia es único en el sentido de que trata las input_features
y las labels
de forma independiente: las input_features
deben ser manejadas por el extractor de características y las labels
por el tokenizador.
Las input_features
ya están rellenadas a 30 segundos y convertidas en un espectrograma de log-Mel de dimensión fija, así que lo único que tenemos que hacer es convertirlas en tensores batch de PyTorch. Esto lo hacemos utilizando el método .pad
del extractor de características con return_tensors=pt
. Cabe destacar que aquí no se aplica ningún relleno adicional, ya que las entradas tienen una dimensión fija, simplemente se convierten en tensores PyTorch.
Por otro lado, las labels
no están rellenadas. Primero rellenamos las secuencias a la longitud máxima en el batch utilizando el método .pad
del tokenizador. Luego, los tokens de relleno se reemplazan por -100
para que estos tokens no se tengan en cuenta al calcular la pérdida. A continuación, eliminamos el token de inicio de transcripción del principio de la secuencia de etiquetas, ya que lo añadiremos más tarde durante el entrenamiento.
Podemos aprovechar el WhisperProcessor
que definimos anteriormente para realizar tanto las operaciones de extracción de características como las de tokenización:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
processor: Any
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
# dividimos las entradas y las etiquetas ya que deben tener longitudes diferentes y necesitan métodos de relleno diferentes
# primero tratamos las entradas de audio simplemente devolviendo tensores de PyTorch
input_features = [{"input_features": feature["input_features"]} for feature in features]
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
# obtenemos las secuencias de etiquetas tokenizadas
label_features = [{"input_ids": feature["labels"]} for feature in features]
# rellenamos las etiquetas hasta la longitud máxima
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
# reemplazamos el relleno con -100 para ignorar correctamente la pérdida
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
# si el token bos se añadió en el paso de tokenización anterior,
# cortamos el token bos aquí ya que se añade más tarde de todos modos
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
labels = labels[:, 1:]
batch["labels"] = labels
return batch
Inicialicemos el collator de datos que acabamos de definir:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
Métricas de Evaluación
A continuación, definimos la métrica de evaluación que utilizaremos en nuestro conjunto de evaluación. Utilizaremos la métrica de Tasa de Error de Palabras (WER), la métrica “de facto” para evaluar los sistemas de ASR. Para obtener más información, consulta la documentación de WER. Cargaremos la métrica de WER de 🤗 Evaluate:
import evaluate
metric = evaluate.load("wer")
Luego, simplemente tenemos que definir una función que tome las predicciones de nuestro modelo y devuelva la métrica de WER. Esta función, llamada compute_metrics
, primero reemplaza -100
con el pad_token_id
en los label_ids
(deshaciendo el paso que aplicamos en el collator de datos para ignorar correctamente los tokens de relleno en la pérdida). Luego, decodifica los ids predichos y de etiquetas a cadenas de texto. Por último, calcula la WER entre las predicciones y las etiquetas de referencia:
def compute_metrics(pred):
pred_ids = pred.predictions
label_ids = pred.label_ids
# reemplazamos -100 con pad_token_id
label_ids[label_ids == -100] = tokenizer.pad_token_id
# no queremos agrupar los tokens al calcular las métricas
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
Cargar un punto de control pre-entrenado
Ahora carguemos el punto de control pre-entrenado de Whisper small
. ¡Nuevamente, esto es trivial a través del uso de 🤗 Transformers!
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
El modelo Whisper tiene identificadores de token que se fuerzan como salidas del modelo antes de que comience la generación autoregresiva ( forced_decoder_ids
). Estos identificadores de token controlan el lenguaje de transcripción y la tarea para ASR sin entrenamiento previo. Para el ajuste fino, estableceremos estos identificadores en None
, ya que entrenaremos el modelo para predecir el lenguaje correcto (Hindi) y la tarea (transcripción). También hay tokens que se suprimen por completo durante la generación ( suppress_tokens
). Estos tokens tienen sus probabilidades logarítmicas establecidas en -inf
, de manera que nunca se muestrean. Anularemos estos tokens a una lista vacía, lo que significa que no se suprimen tokens:
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
Definir los argumentos de entrenamiento
En el último paso, definimos todos los parámetros relacionados con el entrenamiento. A continuación se explican algunos de los parámetros:
output_dir
: directorio local donde se guardarán los pesos del modelo. Este también será el nombre del repositorio en Hugging Face Hub .generation_max_length
: número máximo de tokens para generar autoregresivamente durante la evaluación.save_steps
: durante el entrenamiento, se guardarán puntos de control intermedios y se cargarán de forma asíncrona en el Hub cadasave_steps
pasos de entrenamiento.eval_steps
: durante el entrenamiento, se realizará una evaluación de los puntos de control intermedios cadaeval_steps
pasos de entrenamiento.report_to
: dónde guardar los registros de entrenamiento. Las plataformas compatibles son"azure_ml"
,"comet_ml"
,"mlflow"
,"neptune"
,"tensorboard"
y"wandb"
. Elige tu favorita o déjalo como"tensorboard"
para guardar registros en el Hub.
Para más detalles sobre los otros argumentos de entrenamiento, consulta la documentación de Seq2SeqTrainingArguments .
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
output_dir="./whisper-small-hi", # cambia por el nombre de repositorio que prefieras
per_device_train_batch_size=16,
gradient_accumulation_steps=1, # aumenta el doble por cada reducción a la mitad en el tamaño del lote
learning_rate=1e-5,
warmup_steps=500,
max_steps=4000,
gradient_checkpointing=True,
fp16=True,
evaluation_strategy="steps",
per_device_eval_batch_size=8,
predict_with_generate=True,
generation_max_length=225,
save_steps=1000,
eval_steps=1000,
logging_steps=25,
report_to=["tensorboard"],
load_best_model_at_end=True,
metric_for_best_model="wer",
greater_is_better=False,
push_to_hub=True,
)
Nota : si no deseas cargar los puntos de control del modelo en el Hub, establece push_to_hub=False
.
Podemos pasar los argumentos de entrenamiento al 🤗 Trainer junto con nuestro modelo, conjunto de datos, recolector de datos y función compute_metrics
:
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=common_voice["train"],
eval_dataset=common_voice["test"],
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor.feature_extractor,
)
¡Y con eso, estamos listos para comenzar el entrenamiento!
Entrenamiento
Para iniciar el entrenamiento, simplemente ejecuta:
trainer.train()
El entrenamiento tomará aproximadamente de 5 a 10 horas, dependiendo de tu GPU o la asignada a Google Colab. Dependiendo de tu GPU, es posible que te encuentres con un error de CUDA "out-of-memory"
al iniciar el entrenamiento. En este caso, puedes reducir el per_device_train_batch_size
incrementalmente en factores de 2 y utilizar gradient_accumulation_steps
para compensar.
Resultado de la impresión:
Nuestro mejor WER es del 32.0% – ¡no está mal para 8 horas de datos de entrenamiento! La gran pregunta es cómo se compara esto con otros sistemas ASR. Para eso, podemos ver el hf-speech-bench
, una tabla de clasificación que categoriza los modelos por idioma y conjunto de datos, y posteriormente los clasifica según su WER.
Nuestro modelo afinado mejora significativamente el rendimiento de cero disparo del punto de control Whisper small
, destacando las fuertes capacidades de transferencia de aprendizaje de Whisper.
Podemos enviar automáticamente nuestro punto de control a la tabla de clasificación cuando enviamos los resultados del entrenamiento al Hub, simplemente debemos establecer los argumentos de palabras clave (kwargs) correspondientes. Puede cambiar estos valores para que coincidan con su conjunto de datos, idioma y nombre de modelo:
kwargs = {
"dataset_tags": "mozilla-foundation/common_voice_11_0",
"dataset": "Common Voice 11.0", # un nombre "bonito" para el conjunto de datos de entrenamiento
"dataset_args": "config: hi, split: test",
"language": "hi",
"model_name": "Whisper Small Hi - Sanchit Gandhi", # un nombre "bonito" para su modelo
"finetuned_from": "openai/whisper-small",
"tasks": "automatic-speech-recognition",
"tags": "hf-asr-leaderboard",
}
Los resultados del entrenamiento ahora se pueden cargar en el Hub. Para hacerlo, ejecute el comando push_to_hub
:
trainer.push_to_hub(**kwargs)
Ahora puede compartir este modelo con cualquier persona utilizando el enlace en el Hub. También pueden cargarlo con el identificador "your-username/the-name-you-picked"
, por ejemplo:
from transformers import WhisperForConditionalGeneration, WhisperProcessor
model = WhisperForConditionalGeneration.from_pretrained("sanchit-gandhi/whisper-small-hi")
processor = WhisperProcessor.from_pretrained("sanchit-gandhi/whisper-small-hi")
Aunque el modelo afinado produce resultados satisfactorios en los datos de prueba de Common Voice Hindi, no es óptimo. El propósito de este cuaderno es demostrar cómo se puede afinar finamente los puntos de control pre-entrenados de Whisper en cualquier conjunto de datos ASR multilingüe. Los resultados podrían mejorar optimizando los hiperparámetros de entrenamiento, como la tasa de aprendizaje y la deserción, y utilizando un punto de control pre-entrenado más grande (VoAGI
o large
).
Construyendo una demostración
Ahora que hemos afinado nuestro modelo, ¡podemos construir una demostración para mostrar sus capacidades ASR! Usaremos el pipeline
de 🤗 Transformers, que se encargará de todo el proceso de ASR, desde el preprocesamiento de las entradas de audio hasta la decodificación de las predicciones del modelo. Construiremos nuestra demostración interactiva con Gradio. Gradio es probablemente la forma más sencilla de construir demostraciones de aprendizaje automático; con Gradio, ¡podemos construir una demostración en cuestión de minutos!
Al ejecutar el ejemplo a continuación, se generará una demostración de Gradio donde podemos grabar el habla a través del micrófono de nuestra computadora e ingresarla a nuestro modelo Whisper afinado para transcribir el texto correspondiente:
from transformers import pipeline
import gradio as gr
pipe = pipeline(model="sanchit-gandhi/whisper-small-hi") # cambiar a "your-username/the-name-you-picked"
def transcribe(audio):
text = pipe(audio)["text"]
return text
iface = gr.Interface(
fn=transcribe,
inputs=gr.Audio(source="microphone", type="filepath"),
outputs="text",
title="Whisper Small Hindi",
description="Demo en tiempo real para el reconocimiento del habla en hindi utilizando un modelo Whisper small afinado.",
)
iface.launch()
Comentarios finales
En este blog, cubrimos una guía paso a paso sobre cómo afinar finamente Whisper para ASR multilingüe utilizando 🤗 Datasets, Transformers y el Hugging Face Hub. Consulte el Google Colab si desea probar el afinamiento fino usted mismo. Si está interesado en afinar otros modelos de Transformers, tanto para ASR en inglés como multilingüe, asegúrese de consultar los scripts de ejemplos en examples/pytorch/speech-recognition.