Pérdida NT-Xent (Entropía Cruzada Normalizada Escalada por Temperatura) Explicada e Implementada en PyTorch
NT-Xent loss (Normalized Temperature-scaled Cross Entropy) explained and implemented in PyTorch.
Una explicación intuitiva de la pérdida NT-Xent con una explicación paso a paso de la operación y nuestra implementación en PyTorch
Coescrito con Naresh Singh.

Introducción
Los avances recientes en aprendizaje auto-supervisado y aprendizaje por contraste han emocionado a investigadores y profesionales en el campo del Aprendizaje Automático (ML) para explorar este espacio con un renovado interés.
En particular, el artículo SimCLR que presenta un marco simple para el aprendizaje por contraste de representaciones visuales ha ganado mucha atención en el espacio de auto-supervisión y aprendizaje por contraste.
La idea central detrás del artículo es muy simple: permitir que el modelo aprenda si un par de imágenes se derivó de la misma imagen inicial o de diferentes imágenes iniciales.

El enfoque SimCLR codifica cada imagen de entrada i como un vector de características zi. Hay 2 casos a considerar:
- Desmitificando DreamBooth Una Nueva Herramienta para Personalizar l...
- Conoce a Gorilla LLM aumentado con API de UC Berkeley y Microsoft s...
- Conoce a PANOGEN Un Método Generativo que Potencialmente Puede Crea...
- Pares Positivos: La misma imagen se aumenta utilizando un conjunto diferente de aumentos, y los vectores de características resultantes zi y zj se comparan. Estos vectores de características se ven forzados a ser similares por la función de pérdida.
- Pares Negativos: Diferentes imágenes se aumentan utilizando un conjunto diferente de aumentos, y los vectores de características resultantes zi y zk se comparan. Estos vectores de características se ven forzados a ser diferentes por la función de pérdida.
El resto de este artículo se centrará en explicar y comprender esta función de pérdida, y su implementación eficiente utilizando PyTorch.
La pérdida NT-Xent
En términos generales, el modelo de aprendizaje por contraste recibe 2N imágenes, que provienen de N imágenes subyacentes. Cada una de las N imágenes subyacentes se aumenta utilizando un conjunto aleatorio de aumentos de imagen para producir 2 imágenes aumentadas. Así es como obtenemos 2N imágenes en un único lote de entrenamiento que se alimenta al modelo.
En las siguientes secciones, profundizaremos en los siguientes aspectos de la pérdida NT-Xent.
- El efecto de la temperatura en SoftMax y Sigmoid.
- Una interpretación simple e intuitiva de la pérdida NT-Xent.
- Una implementación paso a paso de NT-Xent en PyTorch.
- Justificación de la necesidad de una función de pérdida multi-etiqueta (NT-BXent).
- Una implementación paso a paso de NT-BXent en PyTorch.
Todo el código para los pasos 2-5 se puede encontrar en este cuaderno. El código para el paso 1 se puede encontrar en este cuaderno.
El efecto de la temperatura en SoftMax y Sigmoid
Para entender todas las partes móviles de la función de pérdida contrastiva que estudiaremos en este artículo, primero debemos entender el efecto de la temperatura en las funciones de activación SoftMax y Sigmoid.
Típicamente, la escala de temperatura se aplica a la entrada de SoftMax o Sigmoid para suavizar o acentuar la salida de esas funciones de activación. Los logits de entrada se dividen por la temperatura antes de pasar a las funciones de activación. Puede encontrar todo el código de esta sección en este cuaderno.
SoftMax: Para SoftMax, una temperatura alta reduce la varianza en la distribución de salida, lo que resulta en la suavización de las etiquetas. Una temperatura baja aumenta la varianza en la distribución de salida y hace que el valor máximo se destaque sobre los demás valores. Vea los gráficos a continuación para ver el efecto de la temperatura en SoftMax cuando se alimenta con el tensor de entrada [0.1081, 0.4376, 0.7697, 0.1929, 0.3626, 2.8451].

Sigmoid: Para Sigmoid, una temperatura alta resulta en una distribución de salida que se acerca a 0.0, mientras que una temperatura baja estira las entradas a valores más altos, estirando las salidas para que estén más cerca de 0.0 o 1.0 dependiendo de la magnitud no firmada de la entrada.

Ahora que entendemos el efecto de varios valores de temperatura en las funciones SoftMax y Sigmoid, veamos cómo esto se aplica a nuestra comprensión de la pérdida NT-Xent.
Interpretando la pérdida NT-Xent
La pérdida NT-Xent se entiende comprendiendo los términos individuales en el nombre de esta pérdida.
- Normalizado: la similitud coseno produce una puntuación normalizada en el rango [-1.0 a +1.0]
- Escala de temperatura: la similitud coseno de todos los pares se escala por una temperatura antes de calcular la pérdida de entropía cruzada
- Pérdida de entropía cruzada: la pérdida subyacente es una pérdida de entropía cruzada de múltiples clases (una etiqueta)
Como se mencionó anteriormente, asumimos que, para un lote de tamaño 2N, los vectores de características en los siguientes índices representan pares positivos (0, 1), (2, 3), (4, 5), (6, 7), … y el resto de las combinaciones representan pares negativos. Este es un factor importante a tener en cuenta durante la interpretación de la pérdida NT-Xent en relación con SimCLR.
Ahora que entendemos lo que significan los términos en el contexto de la pérdida NT-Xent, echemos un vistazo a los pasos mecánicos necesarios para calcular la pérdida NT-Xent en un lote de vectores de características.
- Se calcula la puntuación de similitud coseno para cada uno de los 2N vectores producidos por el modelo SimCLR. Esto resulta en (2N)² puntuaciones de similitud representadas como una matriz 2N x 2N
- Los resultados de la comparación entre el mismo valor (i, i) se descartan (ya que una distribución es perfectamente similar a sí misma y no puede permitir que el modelo aprenda nada útil)
- Cada valor (similitud coseno) se escala por un parámetro de temperatura 𝜏 (que es un hiperparámetro)
- Se aplica la pérdida de entropía cruzada a cada fila de la matriz resultante anterior. El siguiente párrafo explica más en detalle
- Típicamente, se utiliza la media de estas pérdidas (una pérdida por elemento en un lote) para la retropropagación
La forma en que se utiliza la pérdida de entropía cruzada aquí es semánticamente ligeramente diferente de cómo se utiliza en tareas de clasificación estándar. En las tareas de clasificación, se entrena una “cabeza de clasificación” final para producir un vector de probabilidad uno-cero para cada entrada, y se calcula la pérdida de entropía cruzada en ese vector de probabilidad uno-cero ya que estamos efectivamente calculando la diferencia entre 2 distribuciones. Este video explica hermosamente el concepto de pérdida de entropía cruzada. En la pérdida NT-Xent, no hay una correspondencia 1:1 entre una capa entrenable y la distribución de salida. En su lugar, se calcula un vector de características para cada entrada, y luego se calcula la similitud coseno entre cada par de vectores de características. El truco aquí es que como cada imagen es similar a exactamente 1 otra imagen en el lote de entrada (par positivo) (si ignoramos la similitud de un vector de características consigo mismo), podemos considerar esto como un entorno similar a la clasificación donde la distribución de probabilidad de similitud entre imágenes representa una tarea de clasificación donde una de ellas será cercana a 1.0 y el resto estará cerca de 0.0.
Ahora que tenemos una comprensión general sólida de la pérdida NT-Xent, deberíamos estar en buena forma para implementar estas ideas en PyTorch. ¡Empecemos!
Implementación de la pérdida NT-Xent en PyTorch
Todo el código en esta sección se puede encontrar en este cuaderno.
Reutilización de código: Muchas implementaciones de la pérdida NT-Xent vistas en línea implementan todas las operaciones desde cero. Además, algunos de ellos implementan la función de pérdida de manera ineficiente, prefiriendo utilizar bucles for en lugar de la paralelización de la GPU. En cambio, usaremos un enfoque diferente. Implementaremos esta pérdida en términos de la pérdida de entropía cruzada estándar que PyTorch ya proporciona. Para hacer esto, necesitamos transformar las predicciones y las etiquetas de verdad terreno en un formato que la entropía cruzada pueda aceptar. Veamos cómo hacer esto a continuación.
Tensor de Predicciones: Primero, necesitamos crear un tensor de PyTorch que representará la salida de nuestro modelo de aprendizaje contrastivo. Supongamos que nuestro tamaño de lote es 8 (2N=8), y nuestros vectores de características tienen 2 dimensiones (2 valores). Llamaremos a nuestra variable de entrada “x”.
x = torch.randn(8, 2)
Similitud Coseno: A continuación, calcularemos la similitud de coseno de pares para todos los vectores de características en este lote y almacenaremos el resultado en la variable llamada “xcs”. Si la línea de abajo parece confusa, lea los detalles en esta página. Este es el paso de “normalización”.
xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
Como se mencionó anteriormente, debemos ignorar la puntuación de auto-similitud de cada vector de características ya que no contribuye al aprendizaje del modelo y será una molestia innecesaria más adelante cuando queramos calcular la pérdida de entropía cruzada. Para este propósito, definiremos una variable “eye” que es una matriz con los elementos en la diagonal principal que tienen un valor de 1.0 y el resto siendo 0.0. Podemos crear tal matriz usando el siguiente comando.
eye = torch.eye(8)
Ahora convirtamos esto en una matriz booleana para que podamos indexar en la variable “xcs” usando esta matriz de máscara.
eye = eye.bool()
Clonemos el tensor “xcs” en un tensor llamado “y” para que podamos hacer referencia al tensor “xcs” más tarde.
y = xcs.clone()
Ahora, estableceremos los valores a lo largo de la diagonal principal de la matriz de similitud de coseno de pares en -infinito para que cuando calculemos el softmax en cada fila, este valor no contribuya en nada.
y[eye] = float("-inf")
El tensor “y” escalado por un parámetro de temperatura será una de las entradas (predicciones) a la API de pérdida de entropía cruzada en PyTorch. A continuación, necesitamos calcular las etiquetas de verdad terreno (objetivo) que necesitamos alimentar a la API de pérdida de entropía cruzada.
Etiquetas de verdad terreno (tensor objetivo): Para el ejemplo que estamos usando (2N=8), esto es lo que debería parecer el tensor de verdad terreno.
tensor([1, 0, 3, 2, 5, 4, 7, 6])
Eso es porque los siguientes pares de índices en el tensor “y” contienen pares positivos.
(0, 1), (1, 0)
(2, 3), (3, 2)
(4, 5), (5, 4)
(6, 7), (7, 6)
Para interpretar los pares de índices anteriores, miramos un solo ejemplo. El par (4, 5) significa que la columna 5 en la fila 4 debe establecerse en 1.0 (par positivo), que es lo que el tensor anterior también está diciendo. ¡Genial!
Para crear el tensor anterior, podemos usar el siguiente código de PyTorch, que almacena las etiquetas de verdad terreno en la variable “target”.
target = torch.arange(8)target[0::2] += 1target[1::2] -= 1
Pérdida de entropía cruzada: ¡Tenemos todos los ingredientes que necesitamos para calcular nuestra pérdida! Lo único que queda por hacer es llamar a la API de entropía cruzada en PyTorch.
pérdida = F.cross_entropy(y / temperatura, target, reduction="mean")
La variable “pérdida” ahora contiene la pérdida NT-Xent calculada. Envuelva todo el código en una sola función de Python a continuación.
def nt_xent_loss(x, temperatura): assert len(x.size()) == 2 # similitud coseno xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1) xcs[torch.eye(x.size(0)).bool()] = float("-inf") # etiquetas de verdad fundamentales target = torch.arange(8) target[0::2] += 1 target[1::2] -= 1 # pérdida de entropía cruzada estándar return F.cross_entropy(xcs / temperatura, target, reduction="mean")
El código anterior funciona siempre y cuando cada vector de característica tenga exactamente un par positivo en el lote cuando se entrena nuestro modelo de aprendizaje por contraste. Veamos cómo manejar múltiples pares positivos en una tarea de aprendizaje por contraste.
Una pérdida multietiqueta para el aprendizaje por contraste: NT-BXent
En el artículo de SimCLR, cada imagen i tiene exactamente 1 par similar en el índice j. Esto hace que la pérdida de entropía cruzada sea una opción perfecta para la tarea, ya que se asemeja a un problema de múltiples clases. En cambio, si tenemos M> 2 aumentos de la misma imagen alimentados en el lote de entrenamiento único del modelo de aprendizaje por contraste, entonces cada lote tendría M-1 pares similares de imagen para la imagen i. Esta tarea se asemejaría a un problema de múltiples etiquetas.
La elección obvia sería reemplazar la pérdida de entropía cruzada con la pérdida de entropía cruzada binaria. Por lo tanto, el nombre de pérdida NT-BXent, que significa Pérdida de entropía cruzada binaria normalizada y escalada por temperatura.
La formulación a continuación muestra la pérdida Li para el elemento i. La σ en la fórmula a continuación significa la función Sigmoide.
Para evitar el problema del desequilibrio de clases, ponderamos los pares positivos y negativos por el inverso del número de pares positivos y negativos en nuestro mini-lote. La pérdida final en el mini-lote utilizado para la retropropagación será la media de las pérdidas de cada muestra en nuestro mini-lote.
A continuación, enfoquemos nuestra atención en nuestra implementación de la pérdida NT-BXent en PyTorch.
Implementación de la pérdida NT-BXent en PyTorch
Todo el código en esta sección se puede encontrar en este cuaderno.
Reutilización de código: Al igual que en nuestra implementación de la pérdida NT-Xent, reutilizaremos el método de pérdida de entropía cruzada binaria (BCE) proporcionado por PyTorch. La configuración de nuestras etiquetas de verdad será similar a la de un problema de clasificación multietiqueta donde se utiliza la pérdida BCE.
Tensor de predicciones: utilizaremos el mismo tensor de predicciones (8, 2) que usamos para la implementación de la pérdida NT-Xent.
x = torch.randn(8, 2)
Similitud coseno: Dado que el tensor de entrada x es el mismo, el tensor de similitud coseno de todos los pares xcs también será el mismo. Consulte esta página para obtener una explicación detallada de lo que hace la línea a continuación.
xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1)
Para garantizar que la pérdida del elemento en la posición (i, i) sea 0, deberemos realizar algunas acrobacias para que nuestro tensor xcs contenga un valor 1 en cada índice (i, i) después de aplicar Sigmoid. A él. Dado que utilizaremos la pérdida BCE, marcaremos el puntaje de auto-similitud de cada vector de característica con el valor infinito en el tensor xcs. Eso se debe a que aplicar la función sigmoide en el tensor xcs, convertirá el infinito en el valor 1, y configuraremos nuestras etiquetas de verdad para que cada posición (i, i) en las etiquetas de verdad tenga el valor 1.
Creemos un tensor de enmascaramiento que tenga el valor Verdadero a lo largo de la diagonal principal ( xcs tiene puntuaciones de auto-similitud a lo largo de la diagonal principal), y Falso en todas partes más.
eye = torch.eye(8).bool()
Clonemos el tensor “xcs” en un tensor llamado “y” para que podamos referenciar el tensor “xcs” más tarde.
y = xcs.clone()
Ahora, estableceremos los valores a lo largo de la diagonal principal de la matriz de similitud coseno de pares a infinito para que cuando calculemos la sigmoidal en cada fila, obtengamos 1 en estas posiciones.
y[eye] = float("inf")
El tensor “y” escalado por un parámetro de temperatura será una de las entradas (predicciones) para la API de pérdida BCE en PyTorch. A continuación, necesitamos calcular las etiquetas de verdad (objetivo) que necesitamos alimentar a la API de pérdida BCE.
Etiquetas de verdad (tensor objetivo) : Esperaremos que el usuario nos pase el par de todos los pares de índices (x, y) que contienen ejemplos positivos. Esto es una partida de lo que hicimos para la pérdida NT-Xent, ya que los pares positivos eran implícitos, mientras que aquí, los pares positivos son explícitos.
Además de las ubicaciones proporcionadas por el usuario, estableceremos todos los elementos diagonales como pares positivos como se explicó anteriormente. Usaremos la API de indexación de tensor PyTorch para extraer todos los elementos en esas ubicaciones y establecerlos en 1, mientras que el resto se inicializa en 0.
target = torch.zeros(8, 8)pos_indices = torch.tensor([ (0, 0), (0, 2), (0, 4), (1, 4), (1, 6), (1, 1), (2, 3), (3, 7), (4, 3), (7, 6),])# Agregar índices de la diagonal principal como índices positivos.# Esto será útil ya que usaremos BCELoss en PyTorch,# que esperará un valor para los elementos en la diagonal principal también.pos_indices = torch.cat([pos_indices, torch.arange(8).reshape(8, 1).expand(-1, 2)], dim=0)# Establecer los valores en el vector objetivo en 1.target[pos_indices[:,0], pos_indices[:,1]] = 1
Pérdida de entropía cruzada binaria (BCE) : A diferencia de la pérdida NT-Xent, no podemos simplemente llamar a la función torch.nn.functional.binary_cross_entropy_function, ya que queremos ponderar la pérdida positiva y negativa en función de cuántos pares positivos y negativos tiene el elemento en el índice i en el mini-lote actual.
El primer paso es calcular la pérdida de entropía cruzada elemento a elemento.
temperature = 0.1loss = F.binary_cross_entropy((y / temperature).sigmoid(), target, reduction="none")
Crearemos una máscara binaria de pares positivos y negativos y luego crearemos 2 tensores, loss_pos y loss_neg que contengan solo aquellos elementos de la pérdida calculada que correspondan a los pares positivos y negativos.
target_pos = target.bool()target_neg = ~target_pos# loss_pos y loss_neg a continuación contienen valores no nulos solo para aquellos elementos# que son pares positivos y negativos respectivamente.loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos])loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg])
A continuación, sumaremos la pérdida de pares positivos y negativos (por separado) correspondiente a cada elemento i en nuestro mini-lote.
# loss_pos y loss_neg ahora contienen la suma de las pérdidas de pares positivos y negativos# como se calcula en relación con la entrada i.loss_pos = loss_pos.sum(dim=1)loss_neg = loss_neg.sum(dim=1)
Para realizar el ponderado, necesitamos hacer un seguimiento del número de pares positivos y negativos correspondientes a cada elemento i en nuestro mini-lote. Los tensores “num_pos” y “num_neg” almacenarán estos valores.
# num_pos y num_neg a continuación contienen el número de pares positivos y negativos# calculados en relación con la entrada i. En un entorno real, este número debería# ser el mismo para cada elemento de entrada, pero lo dejamos variar aquí para obtener la máxima# flexibilidad.num_pos = target.sum(dim=1)num_neg = target.size(0) - num_pos
¡Tenemos todos los ingredientes que necesitamos para calcular nuestra pérdida! Lo único que tenemos que hacer es ponderar la pérdida positiva y negativa por el número de pares positivos y negativos, y luego promediar la pérdida en el mini-batch.
def nt_bxent_loss(x, pos_indices, temperature): assert len(x.size()) == 2 # Add indexes of the principal diagonal elements to pos_indices pos_indices = torch.cat([ pos_indices, torch.arange(x.size(0)).reshape(x.size(0), 1).expand(-1, 2), ], dim=0) # Ground truth labels target = torch.zeros(x.size(0), x.size(0)) target[pos_indices[:,0], pos_indices[:,1]] = 1.0 # Cosine similarity xcs = F.cosine_similarity(x[None,:,:], x[:,None,:], dim=-1) # Set logit of diagonal element to "inf" signifying complete # correlation. sigmoid(inf) = 1.0 so this will work out nicely # when computing the Binary cross-entropy Loss. xcs[torch.eye(x.size(0)).bool()] = float("inf") # Standard binary cross-entropy loss. We use binary_cross_entropy() here and not # binary_cross_entropy_with_logits() because of # https://github.com/pytorch/pytorch/issues/102894 # The method *_with_logits() uses the log-sum-exp-trick, which causes inf and -inf values # to result in a NaN result. loss = F.binary_cross_entropy((xcs / temperature).sigmoid(), target, reduction="none") target_pos = target.bool() target_neg = ~target_pos loss_pos = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_pos, loss[target_pos]) loss_neg = torch.zeros(x.size(0), x.size(0)).masked_scatter(target_neg, loss[target_neg]) loss_pos = loss_pos.sum(dim=1) loss_neg = loss_neg.sum(dim=1) num_pos = target.sum(dim=1) num_neg = x.size(0) - num_pos return ((loss_pos / num_pos) + (loss_neg / num_neg)).mean()pos_indices = torch.tensor([ (0, 0), (0, 2), (0, 4), (1, 4), (1, 6), (1, 1), (2, 3), (3, 7), (4, 3), (7, 6),])for t in (0.01, 0.1, 1.0, 10.0, 20.0): print(f"Temperatura: {t:5.2f}, Pérdida: {nt_bxent_loss(x, pos_indices, temperature=t)}")
Imprime:
Temperatura: 0.01, Pérdida: 62.898780822753906
Temperatura: 0.10, Pérdida: 4.851151943206787
Temperatura: 1.00, Pérdida: 1.0727109909057617
Temperatura: 10.00, Pérdida: 0.9827173948287964
Temperatura: 20.00, Pérdida: 0.982099175453186
Conclusión
El aprendizaje auto-supervisado es un campo emergente en el aprendizaje profundo y nos permite entrenar modelos en datos no etiquetados. Esta técnica nos permite trabajar en la necesidad de datos etiquetados a gran escala.
En este artículo, aprendimos sobre las funciones de pérdida para el aprendizaje contrastivo. La primera, llamada pérdida NT-Xent, se utiliza para el aprendizaje en un solo par positivo por entrada en un mini-batch. Presentamos la pérdida NT-BXent que se utiliza para el aprendizaje en múltiples (> 1) pares positivos por entrada en un mini-batch. Aprendimos a interpretarlos intuitivamente, basándonos en nuestro conocimiento de la pérdida de entropía cruzada y la pérdida de entropía cruzada binaria. Por último, las implementamos eficientemente en PyTorch.