Presentando Würstchen Difusión rápida para la generación de imágenes

Introducing Würstchen Fast Diffusion for Image Generation

¿Qué es Würstchen?

Würstchen es un modelo de difusión cuyo componente condicional de texto funciona en un espacio latente altamente comprimido de imágenes. ¿Por qué es esto importante? Comprimir datos puede reducir los costos computacionales tanto para el entrenamiento como para la inferencia en órdenes de magnitud. El entrenamiento en imágenes de 1024×1024 es mucho más costoso que el entrenamiento en 32×32. Por lo general, otros trabajos utilizan una compresión relativamente pequeña, en el rango de 4x – 8x de compresión espacial. Würstchen lleva esto al extremo. ¡A través de su diseño novedoso, logra una compresión espacial de 42x! Esto nunca se había visto antes, porque los métodos comunes no logran reconstruir fielmente imágenes detalladas después de una compresión espacial de 16x. Würstchen utiliza una compresión de dos etapas, lo que llamamos Etapa A y Etapa B. La Etapa A es un VQGAN y la Etapa B es un Difusor Autoencoder (más detalles se pueden encontrar en el artículo). Juntas, la Etapa A y la Etapa B se llaman el Decodificador, porque decodifican las imágenes comprimidas de vuelta al espacio de píxeles. Un tercer modelo, la Etapa C, se aprende en ese espacio latente altamente comprimido. Este entrenamiento requiere fracciones de la potencia informática utilizada para los modelos de mejor rendimiento actuales, al tiempo que permite una inferencia más barata y rápida. Nos referimos a la Etapa C como el Prior.

¿Por qué otro modelo de texto a imagen?

Bueno, este es bastante rápido y eficiente. ¡Los mayores beneficios de Würstchen provienen del hecho de que puede generar imágenes mucho más rápido que modelos como Stable Diffusion XL, mientras usa mucha menos memoria! Así que para todos nosotros que no tenemos A100s por ahí, esto será muy útil. Aquí hay una comparación con SDXL en diferentes tamaños de lote:

Además de eso, otro beneficio muy significativo de Würstchen viene con la reducción de los costos de entrenamiento. Würstchen v1, que funciona en 512×512, requirió solo 9,000 horas de GPU para entrenar. Comparando esto con las 150,000 horas de GPU gastadas en Stable Diffusion 1.4 sugiere que esta reducción de costos de 16x no solo beneficia a los investigadores al realizar nuevos experimentos, sino que también abre la puerta a que más organizaciones entrenen dichos modelos. Würstchen v2 utilizó 24,602 horas de GPU. Con resoluciones de hasta 1536, esto sigue siendo 6 veces más barato que SD1.4, que solo se entrenó a 512×512.

También puedes encontrar un video de explicación detallado aquí:

¿Cómo usar Würstchen?

Puedes probarlo utilizando la Demo aquí:

De lo contrario, el modelo está disponible a través de la Biblioteca de Difusores, por lo que puedes usar la interfaz con la que ya estás familiarizado. Por ejemplo, así es como se ejecuta la inferencia utilizando el AutoPipeline:

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

pipeline = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")

caption = "Gato antropomórfico vestido como bombero"
images = pipeline(
    caption,
    height=1024,
    width=1536,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    prior_guidance_scale=4.0,
    num_images_per_prompt=4,
).images

¿En qué tamaños de imagen funciona Würstchen?

Würstchen fue entrenado en resoluciones de imagen entre 1024×1024 y 1536×1536. A veces también observamos buenos resultados en resoluciones como 1024×2048. Siéntete libre de probarlo. También observamos que el Prior (Etapa C) se adapta extremadamente rápido a nuevas resoluciones. Por lo tanto, afinarlo a 2048×2048 debería ser computacionalmente barato.

Modelos en el Hub

Todos los puntos de control también se pueden ver en el Hub de Huggingface. Se pueden encontrar múltiples puntos de control, así como futuras demos y pesos de modelos. En este momento, hay 3 puntos de control disponibles para el Prior y 1 punto de control para el Decodificador. Echa un vistazo a la documentación donde se explican los puntos de control y para qué se pueden utilizar los diferentes modelos de Prior.

Integración de Diffusers

Porque Würstchen está completamente integrado en diffusers, automáticamente viene con varias ventajas y optimizaciones incluidas. Estas incluyen:

  • Uso automático de la atención acelerada con PyTorch 2 SDPA, como se describe a continuación.
  • Soporte para la implementación de atención flash de xFormers, si necesitas usar PyTorch 1.x en lugar de 2.
  • Descarga del modelo, para mover los componentes no utilizados a la CPU mientras no se estén utilizando. Esto ahorra memoria con un impacto de rendimiento insignificante.
  • Descarga secuencial de la CPU, para situaciones donde la memoria es realmente valiosa. El uso de memoria se minimizará, a costa de una inferencia más lenta.
  • Ponderación de la consulta con la biblioteca Compel.
  • Soporte para el dispositivo mps en las Mac con Apple Silicon.
  • Uso de generadores para reproducibilidad.
  • Valores predeterminados razonables para la inferencia para producir resultados de alta calidad en la mayoría de situaciones. ¡Por supuesto, puedes ajustar todos los parámetros como desees!

Técnica de Optimización 1: Atención Flash

A partir de la versión 2.0, PyTorch ha integrado una versión altamente optimizada y eficiente en recursos del mecanismo de atención llamada torch.nn.functional.scaled_dot_product_attention o SDPA. Dependiendo de la naturaleza de la entrada, esta función aprovecha múltiples optimizaciones subyacentes. Su rendimiento y eficiencia de memoria superan al modelo de atención tradicional. Notablemente, la función SDPA refleja las características de la técnica de atención flash, como se destaca en el artículo de investigación Fast and Memory-Efficient Exact Attention with IO-Awareness escrito por Dao y su equipo.

Si estás utilizando Diffusers con PyTorch 2.0 o una versión posterior, y la función SDPA es accesible, estas mejoras se aplican automáticamente. ¡Comienza configurando torch 2.0 o una versión más nueva siguiendo las guías oficiales!

images = pipeline(caption, height=1024, width=1536, prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, prior_guidance_scale=4.0, num_images_per_prompt=4).images

Para obtener una visión más detallada de cómo diffusers aprovecha SDPA, consulta la documentación.

Si estás utilizando una versión de PyTorch anterior a la 2.0, aún puedes lograr una atención eficiente en memoria utilizando la biblioteca xFormers:

pipeline.enable_xformers_memory_efficient_attention()

Técnica de Optimización 2: Compilación de Torch

Si estás buscando un impulso adicional en el rendimiento, puedes utilizar torch.compile. Es mejor aplicarlo tanto al modelo principal del prior como al modelo del decodificador para obtener el mayor aumento en el rendimiento.

pipeline.prior_prior = torch.compile(pipeline.prior_prior , mode="reduce-overhead", fullgraph=True)
pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True)

Ten en cuenta que el paso de inferencia inicial llevará mucho tiempo (hasta 2 minutos) mientras se compilan los modelos. Después de eso, puedes ejecutar la inferencia normalmente:

images = pipeline(caption, height=1024, width=1536, prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, prior_guidance_scale=4.0, num_images_per_prompt=4).images

Y la buena noticia es que esta compilación es una ejecución única. Después de eso, estarás listo para experimentar inferencias más rápidas de manera consistente para las mismas resoluciones de imagen. La inversión inicial de tiempo en la compilación se compensa rápidamente con los beneficios de velocidad posteriores. Para obtener más información sobre torch.compile y sus matices, consulta la documentación oficial.

Recursos

  • Puedes encontrar más información sobre este modelo en la documentación oficial de diffusers.
  • Todos los puntos de control se pueden encontrar en el hub.
  • Puedes probar la demo aquí.
  • ¡Únete a nuestro Discord si quieres discutir proyectos futuros o incluso contribuir con tus propias ideas!