Ajustar finamente ViT para la clasificación de imágenes con 🤗 Transformers
Ajuste fino de ViT para clasificación de imágenes con 🤗 Transformers.
Así como los modelos basados en transformadores han revolucionado el procesamiento del lenguaje natural, ahora estamos presenciando una explosión de artículos que los aplican a todo tipo de otros dominios. Uno de los más revolucionarios de estos fue el Vision Transformer (ViT), que fue introducido en junio de 2021 por un equipo de investigadores de Google Brain.
Este artículo exploró cómo se pueden tokenizar imágenes, al igual que se tokenizan las oraciones, para que puedan ser pasadas a modelos de transformadores para el entrenamiento. Es un concepto bastante simple, en realidad…
- Dividir una imagen en una cuadrícula de subimágenes
- Incrustar cada subimagen con una proyección lineal
- Cada subimagen incrustada se convierte en un token, y la secuencia resultante de subimágenes incrustadas es la secuencia que se pasa al modelo.
Resulta que una vez que hayas hecho lo anterior, puedes preentrenar y ajustar los transformadores tal como estás acostumbrado con las tareas de procesamiento del lenguaje natural. Bastante genial 😎.
En esta entrada de blog, veremos cómo aprovechar 🤗 datasets
para descargar y procesar conjuntos de datos de clasificación de imágenes, y luego usarlos para ajustar finamente un ViT preentrenado con 🤗 transformers
.
- Expertos en Aprendizaje Automático – Margaret Mitchell
- Presentando Decision Transformers en Hugging Face 🤗
- ~No te repitas~
Para comenzar, instalemos ambos paquetes.
pip install datasets transformers
Cargar un conjunto de datos
Comencemos cargando un pequeño conjunto de datos de clasificación de imágenes y echando un vistazo a su estructura.
Usaremos el conjunto de datos beans
, que es una colección de imágenes de hojas de frijol sanas y enfermas. 🍃
from datasets import load_dataset
ds = load_dataset('beans')
ds
Echemos un vistazo al ejemplo número 400 del conjunto de datos 'train'
del conjunto de datos de frijoles. Observarás que cada ejemplo del conjunto de datos tiene 3 características:
image
: Una imagen PILimage_file_path
: La rutastr
al archivo de imagen que se cargó comoimage
labels
: Una característicadatasets.ClassLabel
, que es una representación entera de la etiqueta. (¡Más adelante verás cómo obtener los nombres de las clases en forma de cadena, no te preocupes!)
ex = ds['train'][400]
ex
{
'image': <PIL.JpegImagePlugin ...>,
'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
'labels': 1
}
Echemos un vistazo a la imagen 👀
image = ex['image']
image
¡Definitivamente es una hoja! Pero, ¿qué tipo? 😅
Dado que la característica 'labels'
de este conjunto de datos es una datasets.features.ClassLabel
, podemos usarla para buscar el nombre correspondiente a la ID de etiqueta de este ejemplo.
Primero, accedamos a la definición de la característica 'labels'
.
labels = ds['train'].features['labels']
labels
ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)
Ahora, imprimamos la etiqueta de clase de nuestro ejemplo. Puedes hacerlo utilizando la función int2str
de ClassLabel
, que, como su nombre lo indica, permite pasar la representación entera de la clase para buscar la etiqueta en forma de cadena.
labels.int2str(ex['labels'])
'bean_rust'
Resulta que la hoja mostrada anteriormente está infectada con Bean Rust, una enfermedad grave en las plantas de frijol. 😢
Escribamos una función que muestre una cuadrícula de ejemplos de cada clase para tener una mejor idea de con qué estás trabajando.
import random
from PIL import ImageDraw, ImageFont, Image
def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):
w, h = size
labels = ds['train'].features['labels'].names
grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
draw = ImageDraw.Draw(grid)
font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)
for label_id, label in enumerate(labels):
# Filtrar el conjunto de datos por una sola etiqueta, mezclarlo y tomar algunos ejemplos
ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))
# Graficar los ejemplos de esta etiqueta en una fila
for i, example in enumerate(ds_slice):
image = example['image']
idx = examples_per_class * label_id + i
box = (idx % examples_per_class * w, idx // examples_per_class * h)
grid.paste(image.resize(size), box=box)
draw.text(box, label, (255, 255, 255), font=font)
return grid
show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)
Una cuadrícula de algunos ejemplos de cada clase en el conjunto de datos
Por lo que estoy viendo,
- Mancha angular de las hojas: Tiene manchas marrones irregulares
- Roya de frijol: Tiene manchas marrones circulares rodeadas de un anillo amarillento blanquecino
- Sano: …se ve saludable. 🤷♂️
Cargando Extractor de Características ViT
Ahora sabemos cómo se ven nuestras imágenes y comprendemos mejor el problema que estamos tratando de resolver. ¡Veamos cómo podemos preparar estas imágenes para nuestro modelo!
Cuando se entrenan modelos ViT, se aplican transformaciones específicas a las imágenes que se les proporcionan. ¡Si aplicas las transformaciones incorrectas a tu imagen, el modelo no entenderá lo que está viendo! 🖼 ➡️ 🔢
Para asegurarnos de aplicar las transformaciones correctas, utilizaremos un ViTFeatureExtractor
inicializado con una configuración que se guardó junto con el modelo pre-entrenado que planeamos utilizar. En nuestro caso, usaremos el modelo ‘google/vit-base-patch16-224-in21k’, así que carguemos su extractor de características desde Hugging Face Hub.
from transformers import ViTFeatureExtractor
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
Puedes ver la configuración del extractor de características imprimiéndola.
ViTFeatureExtractor {
"do_normalize": true,
"do_resize": true,
"feature_extractor_type": "ViTFeatureExtractor",
"image_mean": [
0.5,
0.5,
0.5
],
"image_std": [
0.5,
0.5,
0.5
],
"resample": 2,
"size": 224
}
Para procesar una imagen, simplemente pásala a la función de llamada del extractor de características. Esto devolverá un diccionario que contiene valores de píxeles
, que es la representación numérica que se pasará al modelo.
Obtendrás un array de NumPy por defecto, pero si agregas el argumento return_tensors='pt'
, obtendrás tensores torch
en su lugar.
feature_extractor(image, return_tensors='pt')
Esto debería darte algo como…
{
'pixel_values': tensor([[[[ 0.2706, 0.3255, 0.3804, ...]]]])
}
…donde la forma del tensor es (1, 3, 224, 224)
.
Procesando el Conjunto de Datos
Ahora que sabes cómo leer imágenes y transformarlas en entradas, escribamos una función que junte esas dos cosas para procesar un solo ejemplo del conjunto de datos.
def process_example(example):
inputs = feature_extractor(example['image'], return_tensors='pt')
inputs['labels'] = example['labels']
return inputs
process_example(ds['train'][0])
{
'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078, ...]]]]),
'labels': 0
}
Aunque podrías llamar a ds.map
y aplicar esto a cada ejemplo de una vez, esto puede ser muy lento, especialmente si usas un conjunto de datos más grande. En cambio, puedes aplicar una transformación al conjunto de datos. Las transformaciones solo se aplican a los ejemplos cuando los indexas.
Primero, sin embargo, deberás actualizar la última función para aceptar un lote de datos, ya que eso es lo que espera ds.with_transform
.
ds = load_dataset('beans')
def transform(example_batch):
# Toma una lista de imágenes PIL y conviértelas en valores de píxeles
inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
# ¡No olvides incluir las etiquetas!
inputs['labels'] = example_batch['labels']
return inputs
Puedes aplicar esto directamente al conjunto de datos usando ds.with_transform(transform)
.
prepared_ds = ds.with_transform(transform)
Ahora, cada vez que obtengas un ejemplo del conjunto de datos, la transformación se aplicará en tiempo real (tanto en muestras como en segmentos, como se muestra a continuación)
prepared_ds['train'][0:2]
Esta vez, el tensor resultante pixel_values
tendrá forma (2, 3, 224, 224)
.
{
'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078, ..., ]]]]),
'labels': [0, 0]
}
Los datos están procesados y estás listo para comenzar a configurar el proceso de entrenamiento. Esta publicación de blog utiliza el Trainer de 🤗, pero primero debemos hacer algunas cosas:
-
Definir una función de agrupamiento (collate function).
-
Definir una métrica de evaluación. Durante el entrenamiento, el modelo debe evaluarse en base a su precisión de predicción. Debes definir una función
compute_metrics
en consecuencia. -
Cargar un punto de control previamente entrenado. Necesitas cargar un punto de control previamente entrenado y configurarlo correctamente para el entrenamiento.
-
Definir la configuración de entrenamiento.
Después de ajustar el modelo, lo evaluarás correctamente en los datos de evaluación y verificarás que haya aprendido a clasificar correctamente las imágenes.
Definir nuestro agrupador de datos (data collator)
Los lotes (batches) llegan como listas de diccionarios, por lo que simplemente puedes desempaquetarlos y apilarlos en tensores de lote.
Dado que la función collate_fn
devuelve un diccionario de lote, puedes desempaquetar los datos de entrada al modelo más adelante utilizando **
. ✨
import torch
def collate_fn(batch):
return {
'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
'labels': torch.tensor([x['labels'] for x in batch])
}
Definir una métrica de evaluación
La métrica de precisión de datasets
se puede utilizar fácilmente para comparar las predicciones con las etiquetas. A continuación, puedes ver cómo utilizarla dentro de una función compute_metrics
que será utilizada por el Trainer
.
import numpy as np
from datasets import load_metric
metric = load_metric("accuracy")
def compute_metrics(p):
return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
Carguemos el modelo preentrenado. Agregaremos num_labels
en la inicialización para que el modelo cree una capa de clasificación con el número correcto de unidades. También incluiremos los mapeos id2label
y label2id
para tener etiquetas legibles por humanos en el widget del Hub (si eliges push_to_hub
).
from transformers import ViTForImageClassification
labels = ds['train'].features['labels'].names
model = ViTForImageClassification.from_pretrained(
model_name_or_path,
num_labels=len(labels),
id2label={str(i): c for i, c in enumerate(labels)},
label2id={c: str(i) for i, c in enumerate(labels)}
)
¡Casi listo para entrenar! Lo último que se necesita antes de eso es configurar la configuración de entrenamiento mediante la definición de TrainingArguments
.
La mayoría de estas opciones son bastante autoexplicativas, pero una que es bastante importante aquí es remove_unused_columns=False
. Esta opción descartará cualquier característica que no sea utilizada por la función de llamada del modelo. De forma predeterminada, es True
porque generalmente es ideal descartar columnas de características no utilizadas, lo que facilita el desempaquetado de las entradas en la función de llamada del modelo. Pero, en nuestro caso, necesitamos las características no utilizadas (‘image’ en particular) para crear ‘pixel_values’.
Lo que estoy tratando de decir es que tendrás problemas si olvidas establecer remove_unused_columns=False
.
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir="./vit-base-beans",
per_device_train_batch_size=16,
evaluation_strategy="steps",
num_train_epochs=4,
fp16=True,
save_steps=100,
eval_steps=100,
logging_steps=10,
learning_rate=2e-4,
save_total_limit=2,
remove_unused_columns=False,
push_to_hub=False,
report_to='tensorboard',
load_best_model_at_end=True,
)
Ahora, todas las instancias se pueden pasar al Entrenador y estamos listos para comenzar el entrenamiento!
from transformers import Trainer
entrenador = Trainer(
model=modelo,
args=args_entrenamiento,
data_collator=colador_datos,
compute_metrics=calcular_metricas,
train_dataset=ds_entrenamiento_preparado["train"],
eval_dataset=ds_entrenamiento_preparado["validacion"],
tokenizer=extractor_caracteristicas,
)
Entrenamiento 🚀
resultados_entrenamiento = entrenador.train()
entrenador.guardar_modelo()
entrenador.registrar_metricas("entrenamiento", resultados_entrenamiento.metricas)
entrenador.guardar_metricas("entrenamiento", resultados_entrenamiento.metricas)
entrenador.guardar_estado()
Evaluación 📊
metricas = entrenador.evaluar(ds_entrenamiento_preparado['validacion'])
entrenador.registrar_metricas("evaluacion", metricas)
entrenador.guardar_metricas("evaluacion", metricas)
Aquí están los resultados de mi evaluación – ¡Genial! Lo siento, tenía que decirlo.
***** métricas de evaluación *****
epoch = 4.0
eval_accuracy = 0.985
eval_loss = 0.0637
eval_runtime = 0:00:02.13
eval_samples_per_second = 62.356
eval_steps_per_second = 7.97
Por último, si quieres, puedes subir tu modelo al hub. Aquí, lo subiremos si especificaste push_to_hub=True
en la configuración de entrenamiento. Ten en cuenta que para subir al hub, tendrás que tener git-lfs instalado y haber iniciado sesión en tu cuenta de Hugging Face (lo cual se puede hacer a través de huggingface-cli login
).
kwargs = {
"finetuned_from": modelo.configuracion._nombre_o_ruta,
"tareas": "clasificación-de-imagen",
"conjunto_de_datos": 'beans',
"etiquetas": ['clasificación-de-imagen'],
}
if args_entrenamiento.push_to_hub:
entrenador.subir_al_hub('🍻 salud', **kwargs)
else:
entrenador.crear_tarjeta_del_modelo(**kwargs)
El modelo resultante se ha compartido en nateraw/vit-base-beans . Supongo que no tienes imágenes de hojas de frijol por ahí, ¡así que agregué algunos ejemplos para que los pruebes! 🚀