Investigadores de Google DeepMind presentan SynJax una biblioteca de aprendizaje profundo para la distribución de probabilidades estructuradas en JAX.
Google DeepMind presenta SynJax, una biblioteca de aprendizaje profundo en JAX para distribución de probabilidades estructuradas.
Los datos pueden ser vistos como teniendo una estructura en diversas áreas que explica cómo sus componentes se ajustan para formar un todo más grande. Dependiendo de la actividad, esta estructura generalmente es latente y cambia. Considere la Figura 1 para ilustraciones de estructuras distintas en el lenguaje natural. Juntas, las palabras forman una secuencia. A cada palabra en una secuencia se le aplica una etiqueta de parte de la oración. Estas etiquetas están interconectadas, generando una cadena lineal de color rojo. Al segmentar la oración, que está representada con burbujas, las palabras en la oración se pueden unir en grupos pequeños, desarticulados y contiguos. Un examen más exhaustivo del lenguaje revelaría que se pueden formar grupos de manera recursiva, creando una estructura de árbol sintáctico. Las estructuras también pueden conectar dos idiomas.
Por ejemplo, en la misma imagen, una alineación puede vincular una traducción en japonés con una fuente en inglés. Estas construcciones gramaticales son universales. En biología, se pueden encontrar estructuras similares. Los modelos basados en árboles de ARN capturan el aspecto jerárquico del proceso de plegado de proteínas, mientras que la alineación monótona se utiliza para emparejar los nucleótidos en secuencias de ARN. Los datos genómicos también se dividen en grupos contiguos. La mayoría de los modelos actuales de aprendizaje profundo no intentan representar explícitamente la estructura intermedia y en su lugar buscan predecir las variables de salida directamente a partir de la entrada. Estos modelos podrían beneficiarse de una modelización explícita de la estructura de varias maneras. El uso de los sesgos inductivos adecuados podría facilitar una mejor generalización. Esto mejoraría el rendimiento posterior además de la eficiencia de la muestra.
La modelización explícita de la estructura puede incorporar un conjunto de restricciones o métodos específicos del problema. Las decisiones tomadas por el modelo también son más fácilmente comprensibles debido a la estructura discreta. Por último, hay ocasiones en las que la estructura es el resultado del propio aprendizaje. Por ejemplo, pueden ser conscientes de que los datos se explican mediante una estructura oculta de cierta forma, pero necesitan saber más al respecto. Para modelar secuencias, los modelos autorregresivos son la técnica predominante. En algunas situaciones, las estructuras no secuenciales se pueden linealizar y aproximar mediante una estructura secuencial. Estos modelos son fuertes porque no se basan en suposiciones independientes y se pueden entrenar utilizando una gran cantidad de datos. Si bien identificar la estructura ideal o marginalizar las variables ocultas son problemas comunes de inferencia, el muestreo de modelos autorregresivos a menudo no es viable.
El uso de modelos autorregresivos en modelos a gran escala es desafiante porque requieren aproximaciones sesgadas o de alta varianza, que a menudo son computacionalmente costosas. Los modelos sobre grafos de factores que se factorizan de la misma manera que la estructura objetivo son una alternativa a los modelos autorregresivos. Estos modelos pueden calcular de manera precisa y eficiente todos los problemas de inferencia interesantes mediante el uso de métodos especializados. Aunque cada estructura requiere un método único, cada tarea de inferencia no requiere un algoritmo especializado (argmax, muestreo, marginales, entropía, etc.). Para extraer varios números de una sola función para cada tipo de estructura, SynJax utiliza la diferenciación automática, como demostrarán más adelante.
- Asalto Digital” Recaptura la Piedra Rosetta
- Reinventando la Utopía Comunidades Autocreadas para la Era Digital
- ¿Por qué es importante la escala de características en el aprendiza...
La falta de bibliotecas prácticas que ofrezcan implementaciones compatibles con aceleradores de componentes estructurales ha frenado la investigación en distribuciones estructuradas para el entendimiento profundo, especialmente porque estos componentes dependen de algoritmos que con frecuencia no se ajustan directamente a los primitivos de aprendizaje profundo disponibles, a diferencia de los modelos Transformer. Los investigadores de Google Deepmind ofrecen primitivas estructurales fáciles de usar que se combinan dentro del marco de aprendizaje automático JAX, ayudando a SynJax a resolver el desafío. Considere el ejemplo en la Figura 2 para demostrar lo fácil que es usar SynJax. Este código implementa una pérdida de gradiente de políticas que requiere calcular varios parámetros, incluyendo muestreo, argmax, entropía y probabilidad logarítmica, cada uno de los cuales requiere un enfoque separado.
La estructura es un árbol de expansión dirigido no proyectivo con una restricción de una sola arista raíz en esta línea de código. Como resultado, SynJax empleará el enfoque de muestreo de Wilson dist.sample() para árboles de una sola raíz, dist.entropy() y el algoritmo del árbol de expansión máxima de Tarjan para árboles de una sola arista raíz. Los árboles de una sola arista raíz pueden utilizar el Teorema de la Matriz-Árbol. Solo se necesita cambiar una bandera para que SynJax use algoritmos completamente diferentes que sean adecuados para esa estructura: el algoritmo de Kuhlmann para argmax y varias iteraciones del algoritmo de Eisner para otras cantidades, si solo desean alterar levemente el tipo de árboles al imponer la restricción de proyectividad como usuarios. Debido a que SynJax se encarga de todo lo relacionado con dichos algoritmos, el usuario puede concentrarse en el aspecto de modelado de su problema sin implementarlos ni siquiera entender cómo funcionan.