Hugging Face en PyTorch / XLA TPUs

Hugging Face en PyTorch / XLA TPUs.

Entrenamiento de tus Transformers favoritos en TPUs en la nube usando PyTorch / XLA

El proyecto PyTorch-TPU se originó como un esfuerzo de colaboración entre los equipos de Facebook PyTorch y Google TPU y se lanzó oficialmente en la Conferencia de Desarrolladores de PyTorch 2019. Desde entonces, hemos trabajado con el equipo de Hugging Face para brindar soporte de primera clase para el entrenamiento en TPUs en la nube utilizando PyTorch / XLA. Esta nueva integración permite a los usuarios de PyTorch ejecutar y escalar sus modelos en TPUs en la nube mientras mantienen la misma interfaz de entrenadores de Hugging Face.

Esta publicación de blog proporciona una descripción general de los cambios realizados en la biblioteca de Hugging Face, lo que hace la biblioteca de PyTorch / XLA, un ejemplo para comenzar a entrenar tus Transformers favoritos en TPUs en la nube y algunos puntos de referencia de rendimiento. Si no puedes esperar para comenzar con las TPUs, dirígete a la sección “Entrena tu Transformer en TPUs en la nube” – ¡nosotros nos encargamos de todas las mecánicas de PyTorch / XLA por ti dentro del módulo Trainer!

Tipo de dispositivo XLA:TPU

PyTorch / XLA agrega un nuevo tipo de dispositivo xla a PyTorch. Este tipo de dispositivo funciona igual que otros tipos de dispositivo de PyTorch. Por ejemplo, así es como se crea e imprime un tensor XLA:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

Este código debería resultar familiar. PyTorch / XLA utiliza la misma interfaz que PyTorch regular con algunas adiciones. La importación de torch_xla inicializa PyTorch / XLA, y xm.xla_device() devuelve el dispositivo XLA actual. Este puede ser una CPU, GPU o TPU según tu entorno, pero en esta publicación de blog nos centraremos principalmente en TPU.

El módulo Trainer utiliza una clase de datos TrainingArguments para definir los detalles del entrenamiento. Maneja múltiples argumentos, desde tamaños de lote, tasa de aprendizaje, acumulación de gradientes y otros, hasta los dispositivos utilizados. Basado en lo anterior, en TrainingArguments._setup_devices() cuando se utilizan dispositivos XLA:TPU, simplemente devolvemos el dispositivo TPU que se usará por el Trainer:

@dataclass
class TrainingArguments:
    ...
    @cached_property
    @torch_required
    def _setup_devices(self) -> Tuple["torch.device", int]:
        ...
        elif is_torch_tpu_available():
            device = xm.xla_device()
            n_gpu = 0
        ...

        return device, n_gpu

Cálculo de pasos del dispositivo XLA

En un escenario típico de entrenamiento XLA:TPU, estamos entrenando en múltiples núcleos TPU en paralelo (un solo dispositivo Cloud TPU incluye 8 núcleos TPU). Por lo tanto, debemos asegurarnos de que todos los gradientes se intercambien entre las réplicas paralelas de datos mediante la consolidación de los gradientes y la toma de un paso del optimizador. Para esto, proporcionamos xm.optimizer_step(optimizer) que realiza la consolidación de gradientes y el paso del optimizador. En el entrenador de Hugging Face, actualizamos el paso de entrenamiento para utilizar las API de PyTorch / XLA de la siguiente manera:

class Trainer:
…
   def train(self, *args, **kwargs):
       ...
                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)

Tubería de entrada PyTorch / XLA

Hay dos partes principales para ejecutar un modelo PyTorch / XLA: (1) trazar y ejecutar el gráfico del modelo de manera perezosa (consulte la sección “Biblioteca PyTorch / XLA” a continuación para obtener una explicación más detallada) y (2) alimentar tu modelo. Sin ninguna optimización, el trazado/ejecución de tu modelo y la alimentación de datos se ejecutarían en serie, dejando intervalos de tiempo en los que tanto la CPU de tu host como tus aceleradores TPU estarían inactivos, respectivamente. Para evitar esto, proporcionamos una API que combina las dos partes y, por lo tanto, puede superponer el trazado del paso n+1 mientras el paso n aún se está ejecutando.

import torch_xla.distributed.parallel_loader as pl
...
  dataloader = pl.MpDeviceLoader(dataloader, device)

Escritura y carga de puntos de control

Cuando un tensor se guarda desde un dispositivo XLA y luego se carga desde el punto de control, se cargará de nuevo en el dispositivo original. Antes de guardar tensores en su modelo, desea asegurarse de que todos sus tensores estén en dispositivos CPU en lugar de dispositivos XLA. De esta manera, cuando cargue los tensores nuevamente, los cargará a través de dispositivos CPU y luego tendrá la oportunidad de colocarlos en los dispositivos XLA que desee. Proporcionamos la API xm.save() para esto, que se encarga de escribir solo en una ubicación de almacenamiento desde solo un proceso en cada host (o uno globalmente si se utiliza un sistema de archivos compartido entre hosts).

class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
…
    def save_pretrained(self, save_directory):
        ...
        if getattr(self.config, "xla_device", False):
            import torch_xla.core.xla_model as xm

            if xm.is_master_ordinal():
                # Guardar archivo de configuración
                model_to_save.config.save_pretrained(save_directory)
            # xm.save se encarga de guardar solo desde el maestro
            xm.save(state_dict, output_model_file)

class Trainer:
…
   def train(self, *args, **kwargs):
       ...
       if is_torch_tpu_available():
           xm.rendezvous("saving_optimizer_states")
           xm.save(self.optimizer.state_dict(),
                   os.path.join(output_dir, "optimizer.pt"))
           xm.save(self.lr_scheduler.state_dict(),
                   os.path.join(output_dir, "scheduler.pt"))

Biblioteca PyTorch / XLA

PyTorch / XLA es un paquete de Python que utiliza el compilador de álgebra lineal XLA para conectar el marco de aprendizaje profundo PyTorch con dispositivos XLA, que incluyen CPU, GPU y Cloud TPUs. Una parte del siguiente contenido también está disponible en nuestro API_GUIDE.md.

Los tensores de PyTorch / XLA son perezosos

Usar tensores y dispositivos XLA solo requiere cambiar algunas líneas de código. Sin embargo, aunque los tensores de XLA actúan de manera similar a los tensores de CPU y CUDA, sus internos son diferentes. Los tensores de CPU y CUDA lanzan operaciones de inmediato o de forma ansiosa. Los tensores de XLA, por otro lado, son perezosos. Registran las operaciones en un grafo hasta que se necesiten los resultados. Deferir la ejecución de esta manera permite que XLA la optimice. Un grafo de múltiples operaciones separadas puede fusionarse en una sola operación optimizada.

La ejecución perezosa generalmente es invisible para el llamador. PyTorch / XLA construye automáticamente los gráficos, los envía a los dispositivos XLA y se sincroniza al copiar datos entre un dispositivo XLA y la CPU. Insertar una barrera al realizar un paso de optimización sincroniza explícitamente la CPU y el dispositivo XLA.

Esto significa que cuando llame a model(input) para el pase de ejecución hacia adelante, calcule su pérdida loss.backward() y realice un paso de optimización xm.optimizer_step(optimizer), el grafo de todas las operaciones se construye en segundo plano. Solo cuando evalúe explícitamente el tensor (por ejemplo, imprimir el tensor o moverlo a un dispositivo CPU) o marque un paso (esto lo hará el MpDeviceLoader cada vez que lo recorra), se ejecutará el paso completo.

Rastrear, compilar, ejecutar y repetir

Desde el punto de vista del usuario, un régimen de entrenamiento típico para un modelo que se ejecuta en PyTorch / XLA implica ejecutar un pase hacia adelante, un pase hacia atrás y un paso de optimización. Desde el punto de vista de la biblioteca PyTorch / XLA, las cosas se ven un poco diferentes.

Mientras un usuario ejecuta sus pases hacia adelante y hacia atrás, se traza un gráfico de representación intermedia (IR) sobre la marcha. El gráfico de IR que lleva a cada tensor raíz/salida se puede inspeccionar de la siguiente manera:

>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>> t = torch.tensor(1, device=xm.xla_device())
>>> s = t*t
>>> print(torch_xla._XLAC._get_xla_tensors_text([s]))
IR {
  %0 = s64[] prim::Constant(), value=1
  %1 = s64[] prim::Constant(), value=0
  %2 = s64[] xla::as_strided_view_update(%1, %0), size=(), stride=(), storage_offset=0
  %3 = s64[] aten::as_strided(%2), size=(), stride=(), storage_offset=0
  %4 = s64[] aten::mul(%3, %3), ROOT=0
}

Este gráfico en vivo se acumula mientras se ejecutan los pases hacia adelante y hacia atrás en el programa del usuario, y una vez que se llama a xm.mark_step() (indirectamente por pl.MpDeviceLoader), se corta el gráfico de tensores en vivo. Esta truncación marca la finalización de un paso y posteriormente reducimos el gráfico IR a Operaciones de Nivel Superior XLA (HLO), que es el lenguaje IR para XLA.

Este gráfico HLO se compila en un binario TPU y posteriormente se ejecuta en los dispositivos TPU. Sin embargo, este paso de compilación puede ser costoso, generalmente lleva más tiempo que un solo paso, por lo que si compiláramos el programa del usuario en cada paso, el costo sería alto. Para evitar esto, tenemos caches que almacenan binarios TPU compilados ordenados por los identificadores únicos de los gráficos HLO. Entonces, una vez que esta caché de binarios TPU se ha poblado en el primer paso, los pasos posteriores típicamente no tendrán que volver a compilar nuevos binarios TPU; en su lugar, simplemente pueden buscar los binarios necesarios en la caché.

Dado que las compilaciones TPU suelen ser mucho más lentas que el tiempo de ejecución del paso, esto significa que si el gráfico sigue cambiando de forma, tendremos fallos en la caché y compilaremos con demasiada frecuencia. Para minimizar los costos de compilación, recomendamos mantener las formas de los tensores estáticas siempre que sea posible. Las formas de la biblioteca Hugging Face ya son estáticas en su mayor parte, con los tokens de entrada siendo ajustados correctamente, por lo que a lo largo del entrenamiento la caché debería ser golpeada consistentemente. Esto se puede verificar utilizando las herramientas de depuración que proporcionan PyTorch / XLA. En el siguiente ejemplo, se puede ver que la compilación solo ocurrió 5 veces (CompileTime) mientras que la ejecución ocurrió en cada uno de los 1220 pasos (ExecuteTime):

>>> import torch_xla.debug.metrics as met
>>> print(met.metrics_report())
Metric: CompileTime
  TotalSamples: 5
  Accumulator: 28s920ms153.731us
  ValueRate: 092ms152.037us / segundo
  Rate: 0.0165028 / segundo
  Percentiles: 1%=428ms053.505us; 5%=428ms053.505us; 10%=428ms053.505us; 20%=03s640ms888.060us; 50%=03s650ms126.150us; 80%=11s110ms545.595us; 90%=11s110ms545.595us; 95%=11s110ms545.595us; 99%=11s110ms545.595us
Metric: DeviceLockWait
  TotalSamples: 1281
  Accumulator: 38s195ms476.007us
  ValueRate: 151ms051.277us / segundo
  Rate: 4.54374 / segundo
  Percentiles: 1%=002.895us; 5%=002.989us; 10%=003.094us; 20%=003.243us; 50%=003.654us; 80%=038ms978.659us; 90%=192ms495.718us; 95%=208ms893.403us; 99%=221ms394.520us
Metric: ExecuteTime
  TotalSamples: 1220
  Accumulator: 04m22s555ms668.071us
  ValueRate: 923ms872.877us / segundo
  Rate: 4.33049 / segundo
  Percentiles: 1%=045ms041.018us; 5%=213ms379.757us; 10%=215ms434.912us; 20%=217ms036.764us; 50%=219ms206.894us; 80%=222ms335.146us; 90%=227ms592.924us; 95%=231ms814.500us; 99%=239ms691.472us
Counter: CachedCompile
  Value: 1215
Counter: CreateCompileHandles
  Value: 5
...

Entrena tu transformador en Cloud TPUs

Para configurar tu VM y Cloud TPUs, sigue las secciones “Configurar una instancia de Motor de Cómputo” y “Lanzar un recurso de Cloud TPU” (versión pytorch-1.7 al momento de escribir). Una vez que hayas creado tu VM y Cloud TPU, usarlos es tan simple como hacer SSH a tu VM de GCE y ejecutar los siguientes comandos para iniciar el entrenamiento de bert-large-uncased (el tamaño del lote es para un dispositivo v3-8, puede quedarse sin memoria en un v2-8):

conda activate torch-xla-1.7
export TPU_IP_ADDRESS="INGRESE_SU_DIRECCIÓN_IP_DE_TPU"  # ej. 10.0.0.2
export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
git clone -b v4.2.2 https://github.com/huggingface/transformers.git
cd transformers && pip install .
pip install datasets==1.2.1
python examples/xla_spawn.py \
  --num_cores 8 \
  examples/language-modeling/run_mlm.py \
  --dataset_name wikitext \
  --dataset_config_name wikitext-103-raw-v1 \
  --max_seq_length 512 \
  --pad_to_max_length \
  --logging_dir ./tensorboard-metrics \
  --cache_dir ./cache_dir \
  --do_train \
  --do_eval \
  --overwrite_output_dir \
  --output_dir language-modeling \
  --overwrite_cache \
  --tpu_metrics_debug \
  --model_name_or_path bert-large-uncased \
  --num_train_epochs 3 \
  --per_device_train_batch_size 8 \
  --per_device_eval_batch_size 8 \
  --save_steps 500000

Lo anterior debería completar el entrenamiento en aproximadamente menos de 200 minutos con una perplejidad de evaluación de ~3.25.

Benchmarks de rendimiento

La siguiente tabla muestra el rendimiento del entrenamiento de bert-large-uncased en un sistema TPU en la nube v3-8 (que contiene 4 chips TPU v3) ejecutando PyTorch / XLA. El conjunto de datos utilizado para todas las mediciones de referencia es el conjunto de datos WikiText103, y utilizamos el script run_mlm.py proporcionado en los ejemplos de Hugging Face. Para asegurarnos de que las cargas de trabajo no estén limitadas por la CPU del host, utilizamos la configuración de CPU n1-standard-96 para estas pruebas, pero es posible que también pueda utilizar configuraciones más pequeñas sin afectar el rendimiento.

Comenzar con PyTorch / XLA en TPUs

Consulte la sección “Ejecución en TPUs” en los ejemplos de Hugging Face para comenzar. Para obtener una descripción más detallada de nuestras APIs, consulte nuestra API_GUIDE y, para obtener las mejores prácticas de rendimiento, consulte nuestra guía de TROUBLESHOOTING. Para obtener ejemplos genéricos de PyTorch / XLA, ejecute los siguientes Cuadernos de Colab que ofrecemos con acceso gratuito a Cloud TPU. Para ejecutar directamente en GCP, consulte nuestros tutoriales etiquetados como “PyTorch” en nuestro sitio de documentación.

¿Tiene alguna otra pregunta o problema? Abra un problema o pregunta en https://github.com/huggingface/transformers/issues o directamente en https://github.com/pytorch/xla/issues.