Implementando matemáticas en documentos de aprendizaje profundo en un código eficiente de PyTorch Pérdida contrastiva SimCLR

Implementación eficiente de matemáticas en documentos de aprendizaje profundo con PyTorch y pérdida contrastiva en SimCLR.

Aprendiendo a implementar fórmulas matemáticas avanzadas en código PyTorch eficiente.

Foto de Jeswin Thomas en Unsplash

Introducción

Una de las mejores formas de profundizar en la comprensión de las matemáticas detrás de los modelos de aprendizaje profundo y las funciones de pérdida, y también una excelente manera de mejorar tus habilidades en PyTorch, es acostumbrarse a implementar por ti mismo los documentos de aprendizaje profundo.

Los libros y las publicaciones en blogs pueden ayudarte a comenzar a programar y aprender los conceptos básicos en ML / DL, pero después de estudiar unos cuantos y volverte bueno en las tareas rutinarias del campo, pronto te darás cuenta de que estás solo en el viaje de aprendizaje y encontrarás la mayoría de los recursos en línea aburridos y superficiales. Sin embargo, creo que si puedes estudiar nuevos documentos de aprendizaje profundo a medida que se publican y entender las partes matemáticas requeridas en ellos (no necesariamente todas las pruebas matemáticas detrás de las teorías de los autores) y eres un programador capaz que puede implementarlas en un código eficiente, nada puede impedirte mantenerte al día en el campo y aprender nuevas ideas.

Implementación de la pérdida contrastiva

Presentaré mi rutina y los pasos que sigo para implementar las matemáticas en documentos de aprendizaje profundo utilizando un ejemplo no trivial: la pérdida contrastiva en el documento SimCLR.

Aquí está la formulación matemática de la pérdida:

Pérdida contrastiva (NT-Xent) del documento SimCLR | de https://arxiv.org/pdf/2002.05709.pdf

¡Estoy de acuerdo en que la mera apariencia de la fórmula puede ser intimidante! Y es posible que estés pensando que debe haber muchas implementaciones de PyTorch disponibles en GitHub, así que vamos a usarlas 🙂 y sí, tienes razón. Hay docenas de implementaciones en línea. Sin embargo, creo que este es un buen ejemplo para practicar esta habilidad y podría servir como un buen punto de partida.

Pasos para implementar matemáticas en código

Mi rutina para implementar las matemáticas en los documentos en código PyTorch eficiente es la siguiente:

  1. Comprender las matemáticas y explicarlas en términos simples
  2. Implementar una versión inicial usando simples bucles “for” de Python, sin multiplicaciones de matrices sofisticadas por ahora
  3. Convertir tu código en un código eficiente y amigable con las matrices de PyTorch

OK, vamos directamente al primer paso.

Paso 1: Comprender las matemáticas y explicarlas en términos simples

Supongo que tienes conocimientos básicos de álgebra lineal y estás familiarizado con las notaciones matemáticas. Si no lo estás, puedes usar esta herramienta para saber qué representa cada uno de estos símbolos y qué hacen en matemáticas, simplemente dibujando el símbolo. También puedes consultar esta increíble página de Wikipedia donde se describen la mayoría de las notaciones. Estas son las oportunidades en las que aprendes cosas nuevas, buscando y leyendo lo que se necesita en el momento en que lo necesitas. Creo que es una forma más eficiente de aprender, en lugar de comenzar con un libro de matemáticas desde cero y dejarlo de lado después de unos días 🙂

Volviendo a nuestro tema. Como el párrafo anterior a la fórmula agrega más contexto, en la estrategia de aprendizaje SimCLR comienzas con N imágenes y las transformas 2 veces para obtener vistas aumentadas de esas imágenes (ahora son 2 * N imágenes). Luego, pasas estas 2 * N imágenes a través de un modelo para obtener vectores de incrustación para cada una de ellas. Ahora, quieres acercar los vectores de incrustación de las 2 vistas aumentadas de la misma imagen (un par positivo) en el espacio de incrustación (y hacer lo mismo para todos los demás pares positivos). Una forma de medir qué tan similares (cercanos, en la misma dirección) son dos vectores es usando la similitud del coseno, que se define como sim(u, v) (busca la definición en la imagen de arriba).

En términos simples, lo que la fórmula describe es que para cada elemento en nuestro lote, que es la incrustación de una de las vistas aumentadas de una imagen, (Recuerda: el lote contiene todas las incrustaciones de las vistas aumentadas de diferentes imágenes → si comenzamos con N imágenes, el lote tiene un tamaño de 2*N), primero encontramos la incrustación de la otra vista aumentada de esa imagen para hacer una pareja positiva. Luego, calculamos la similitud del coseno de estas dos incrustaciones y la exponentiamos (el numerador de la fórmula). Luego, calculamos la exponenciación de la similitud del coseno de todas las otras parejas que podemos construir con nuestro primer vector de incrustación con el que comenzamos (excepto la pareja consigo misma, esto es lo que significa 1[k!=i] en la fórmula), y las sumamos para construir el denominador. Ahora, podemos dividir el numerador por el denominador y tomar el logaritmo natural de eso y ¡cambiar el signo! Ahora tenemos la pérdida del primer elemento en nuestro lote. Solo necesitamos repetir el mismo proceso para todos los demás elementos en el lote y luego tomar el promedio para poder llamar al método .backward() de PyTorch para calcular los gradientes.

Paso 2: ¡Implementándolo usando código Python simple, con bucles “for” ingenuos!

Implementación simple en Python, usando bucles “for” lentos

Repasemos el código. Digamos que tenemos dos imágenes: A y B. La variable aug_views_1 contiene las incrustaciones (cada una de tamaño 3) de una vista aumentada de estas dos imágenes (A1 y B1), al igual que aug_views_2 (A2 y B2); por lo tanto, el primer elemento en ambas matrices se relaciona con la imagen A y el segundo elemento de ambas se relaciona con la imagen B. Concatenamos las dos matrices en la matriz de proyecciones (que tiene 4 vectores en ella: A1, B1, A2, B2).

Para mantener la relación de los vectores en la matriz de proyecciones, definimos un diccionario de pares_positivos para almacenar qué dos elementos están relacionados en la matriz concatenada. (¡pronto explicaré lo del F.normalize()!)

Como puedes ver en las siguientes líneas de código, recorro los elementos en la matriz de proyecciones en un bucle for, encuentro el vector relacionado usando nuestro diccionario y luego calculo la similitud del coseno. Puede que te preguntes por qué no divides por el tamaño de los vectores, como sugiere la fórmula de la similitud del coseno. El punto es que antes de comenzar el bucle, utilizando la función F.normalize, normalizo todos los vectores en nuestra matriz de proyección para que tengan un tamaño de 1. Por lo tanto, no es necesario dividir por el tamaño en la línea donde calculamos la similitud del coseno.

Después de construir nuestro numerador, encuentro todos los demás índices de vectores en el lote (excepto el mismo índice i), para calcular las similitudes del coseno que componen el denominador. Finalmente, calculo la pérdida dividiendo el numerador por el denominador y aplicando la función logarítmica y cambiando el signo. Asegúrate de jugar con el código para entender qué sucede en cada línea.

Paso 3: Convirtiéndolo en código PyTorch eficiente y amigable con las matrices

El problema con la implementación anterior en Python es que es demasiado lenta para ser utilizada en nuestro flujo de entrenamiento; necesitamos deshacernos de los bucles lentos “for” y convertirlo en multiplicaciones de matrices y manipulaciones de arreglos para aprovechar la potencia de la paralelización.

Implementación en PyTorch

Veamos qué sucede en este fragmento de código. Esta vez, he introducido los tensores labels_1 y labels_2 para codificar las clases arbitrarias a las que pertenecen estas imágenes, ya que necesitamos una forma de codificar la relación de las imágenes A1, A2 y B1, B2. No importa si eliges las etiquetas 0 y 1 (como hice) o digamos 5 y 8.

Después de concatenar tanto las incrustaciones como las etiquetas, comenzamos creando una matriz de similitud que contiene la similitud del coseno de todas las posibles parejas.

Cómo se ve la matriz de similitud: las celdas verdes contienen nuestras parejas positivas, las celdas naranjas son las parejas que deben ignorarse en el denominador | Visualización del autor

La visualización anterior es todo lo que necesitas 🙂 para entender cómo funciona el código y por qué estamos realizando los pasos en él. Considerando la primera fila de sim_matrix, podemos calcular la pérdida para el primer elemento en el lote (A1) de la siguiente manera: necesitamos dividir A1A2 (exponenciado) por la suma de A1B1, A1A2 y A1B2 (cada uno exponenciado primero) y guardar el resultado en el primer elemento de un tensor que almacena todas las pérdidas. Por lo tanto, primero debemos crear una máscara para encontrar las celdas verdes en la visualización anterior. Las dos líneas de código que definen la variable máscara hacen exactamente esto. El numerador se calcula multiplicando nuestra sim_matrix por la máscara que acabamos de crear, y luego sumando los elementos de cada fila (después de aplicar la máscara, solo habrá un elemento distinto de cero en cada fila, es decir, las celdas verdes). Para calcular el denominador, debemos sumar sobre cada fila, ignorando las celdas naranjas de la diagonal. Para hacer esto, utilizaremos el método .diag() de los tensores de PyTorch. ¡El resto es autoexplicativo!

Bono: Usar asistentes de IA (ChatGPT, Copilot, …) para implementar la fórmula

Disponemos de excelentes herramientas para ayudarnos a entender e implementar las matemáticas en los documentos de aprendizaje profundo. Por ejemplo, puedes pedirle a ChatGPT (u otras herramientas similares) que implemente el código en PyTorch después de darle la fórmula del documento. En mi experiencia, ChatGPT puede ser de gran ayuda y proporcionar las mejores respuestas finales con menos ensayo y error si puedes llegar de alguna manera a la implementación ingenua del bucle for en Python. Dale esa implementación ingenua a ChatGPT y pídele que la convierta en un código eficiente de PyTorch que solo use multiplicaciones de matrices y manipulaciones de tensores; te sorprenderá la respuesta 🙂

Lectura adicional

Te animo a que consultes las siguientes dos excelentes implementaciones de la misma idea para aprender cómo puedes extender esta implementación considerando situaciones más matizadas, como en el entorno de aprendizaje contrastivo supervisado.

  1. Supervised Contrastive Loss, de Guillaume Erhard
  2. SupContrast, de Yonglong Tian

Sobre mí

Soy Moein Shariatnia, un desarrollador de aprendizaje automático y estudiante de medicina, centrado en utilizar soluciones de aprendizaje profundo para aplicaciones de imágenes médicas. Mi investigación se centra principalmente en investigar la generalización de los modelos profundos en diversas circunstancias. No dudes en contactarme por correo electrónico, Twitter o LinkedIn.