Graph Neural Networks
En los últimos años hemos presenciado un incremento en la popularidad de conceptos como aprendizaje profundo, redes neuronales, etcétera. Este hecho se ha producido debido a un fácil acceso a una mayor cantidad de datos y capacidad computacional. Hoy vamos a hablar de las Graph Neural Networks en este artículo.
Como podemos ver en el siguiente gráfico, cuando se dispone de muchos datos, las redes neuronales profundas mejoran con diferencia los algoritmos de aprendizaje automático (Machine Learning) clásicos.
Estas técnicas han obtenido resultados asombrosos en áreas tan complejas como la visión por computador o el procesamiento del lenguaje natural, ofreciendo soluciones a problemas que se creían intratables. Para cada campo, se han creado modelos de aprendizaje profundo los cuales introducen un conocimiento previo sobre la estructura de los datos. Por ejemplo, las redes neuronales convolucionales aprovechan la estructura de la imagen y la relación entre píxeles vecinos, lo que permite (i) definir características locales, (ii) reducir la dimensionalidad del problema y (iii) reducir el número de parámetros.
Actualmente, hay un incremento de aplicaciones cuyos datos no son estructurados y necesitan modelar interacciones entre objetos como pueden ser los usuarios de una red social, el enlace de átomos para formar moléculas, etc. Estas interacciones se modelan mediante grafos y los modelos que han conseguido el estado del arte en un gran número de aplicaciones son las Graph Neural Networks (GNNs). Para entender qué son las GNNs y cuáles son sus aplicaciones, debemos entender primero qué es un grafo, qué información contiene y cómo se representa.
Grafo
En términos generales, un grafo G es una estructura de datos que sirve para modelar interacciones o conexiones entre objetos, representados como nodos. Formalmente, un grafo G es una tupla ( V, E ) donde V es el conjunto de nodos (objetos a modelar) y E es el conjunto de interacciones entre esos nodos. El ejemplo más conocido e intuitivo son las redes sociales como Facebook, donde cada usuario representa un nodo del grafo. Por lo tanto, si dos usuarios son amigos en esa plataforma, hay una conexión entre esos dos nodos.
Matriz de Adyacencia
Es común representar un grafo de N nodos mediante su matriz de adyacencia (A). Ésta, es una matriz NxN, es decir, con una fila/columna para cada nodo y cuyas entradas Aij valen 1 si hay una conexión entre los nodos i y j, o 0 si no la hay:
En la siguiente imagen podemos ver cómo sería la matriz de adyacencia para este grafo de 6 nodos. Como podéis observar, esta matriz es simétrica. Ésto se debe a que el grafo es unidireccional, es decir, las conexiones no tienen dirección, dos nodos o están conectados o no lo están. Si hubiera dirección, en general la matriz A no sería simétrica.
Es posible también que no todas las conexiones del grafo sean igual de importantes y por lo tanto haya unos “pesos” asignados a cada elemento de E. Ésto implicaría que las entradas de A podrían tomar otros valores además de 0s y 1s.
Graph Neural Networks (GNNs)
Bueno, al lío. Vamos a ver qué son y qué hacen este tipo de redes neuronales. La verdad es que es un campo en expansión y nuevas definiciones e implementaciones van saliendo frecuentemente. Inicialmente, debido al gran éxito de las redes neuronales convolucionales en imágenes y con el objetivo de extender esa operación a grafos, había, entre otras, dos líneas de investigación: una espectral cuya intención era utilizar propiedades de la transformada de Fourier y una espacial, basada en teoría de grafos más clásica como el algoritmo de Weisfeiler-Lehman.
En este artículo hablaremos de uno de los frameworks más populares a día de hoy, el cual se basa en un proceso de envío de mensajes entre nodos (message passing) y engloba gran parte de las definiciones espectrales y espaciales. A partir de ahora lo llamaremos MPNN (del inglés Message Passing Neural Networks).
La idea principal de las MPNN es la de, en cada capa de la red neuronal, aprender nuevas características/estados/representaciones/etc para todos los nodos del grafo. Para ello, se lleva a cabo el envío de mensajes entre éstos, en el cual cada nodo (i) envía un mensaje a sus vecinos (con quién está conectado) basado en su estado actual, (ii) recibe y agrega todos mensajes de sus vecinos y (iii) actualiza su representación basándose en su estado actual y la agregación. Por lo tanto, en la capa l, cada nodo realiza el siguiente cálculo:
La función MENSAJE(·) es básicamente una red neuronal, la cual toma como entrada la representación de un nodo y calcula un mensaje de salida. Podría ser que también tomase más información, como la representación del nodo de destino v, el peso entre el nodo u y el nodo v, etcétera.
La función AGREGAR(·), debe ser una función que no requiera un número fijo de elementos de entrada (no todos los nodos tienen el mismo número de vecinos) ni del orden (no debe importar el orden de los mensajes recibidos). Los ejemplos más claros son la suma y la media de mensajes.
Finalmente, la función ACTUALIZAR(·) es la que toma la agregación de todos los mensajes y actualiza el estado del nodo de destino. Ésta podría ir de algo tan simple como la función identidad (no hacer nada) a una función de activación o incluso otra red neuronal.
Vamos a ver un ejemplo gráfico, en el cual tenemos una sola capa de MPNN. Si nos fijamos solamente en el nodo B (tened en cuenta que esto se calcula para todos los nodos en cada capa), el resultado es el siguiente: Los vecinos de B son los nodos A, C y D. Con cada uno de ellos se calcula un mensaje a través de la función MENSAJE(·), en azul. El nodo B recibe los tres mensajes y los agrega siguiendo la estrategia marcada por AGREGAR(·), en verde. Finalmente, el nuevo estado del nodo B se calcula mediante una última función ACTUALIZAR(·), en rojo.
Figura 2: Comparación entre Machine Learning y Deep Learning en función de la cantidad de datos disponibles. Fuente: Canadian Association of Radiologists White Paper on Artificial Intelligence in Radiology
Estas técnicas han obtenido resultados asombrosos en áreas tan complejas como la visión por computador o el procesamiento del lenguaje natural, ofreciendo soluciones a problemas que se creían intratables. Para cada campo, se han creado modelos de aprendizaje profundo los cuales introducen un conocimiento previo sobre la estructura de los datos. Por ejemplo, las redes neuronales convolucionales aprovechan la estructura de la imagen y la relación entre píxeles vecinos, lo que permite (i) definir características locales, (ii) reducir la dimensionalidad del problema y (iii) reducir el número de parámetros.
Figura 3: Libro de Machine Learning aplicado a Ciberseguridad de Carmen Torrano, Fran Ramírez, Paloma Recuero, José Torres y Santiago Hernández. |
Grafo
En términos generales, un grafo G es una estructura de datos que sirve para modelar interacciones o conexiones entre objetos, representados como nodos. Formalmente, un grafo G es una tupla ( V, E ) donde V es el conjunto de nodos (objetos a modelar) y E es el conjunto de interacciones entre esos nodos. El ejemplo más conocido e intuitivo son las redes sociales como Facebook, donde cada usuario representa un nodo del grafo. Por lo tanto, si dos usuarios son amigos en esa plataforma, hay una conexión entre esos dos nodos.
Matriz de Adyacencia
Es común representar un grafo de N nodos mediante su matriz de adyacencia (A). Ésta, es una matriz NxN, es decir, con una fila/columna para cada nodo y cuyas entradas Aij valen 1 si hay una conexión entre los nodos i y j, o 0 si no la hay:
Figura 4: Definición de la matriz de adyacencia
En la siguiente imagen podemos ver cómo sería la matriz de adyacencia para este grafo de 6 nodos. Como podéis observar, esta matriz es simétrica. Ésto se debe a que el grafo es unidireccional, es decir, las conexiones no tienen dirección, dos nodos o están conectados o no lo están. Si hubiera dirección, en general la matriz A no sería simétrica.
Figura 5 : Ejemplo de grafo y su correspondiente matriz de adyacencia
Es posible también que no todas las conexiones del grafo sean igual de importantes y por lo tanto haya unos “pesos” asignados a cada elemento de E. Ésto implicaría que las entradas de A podrían tomar otros valores además de 0s y 1s.
Graph Neural Networks (GNNs)
Bueno, al lío. Vamos a ver qué son y qué hacen este tipo de redes neuronales. La verdad es que es un campo en expansión y nuevas definiciones e implementaciones van saliendo frecuentemente. Inicialmente, debido al gran éxito de las redes neuronales convolucionales en imágenes y con el objetivo de extender esa operación a grafos, había, entre otras, dos líneas de investigación: una espectral cuya intención era utilizar propiedades de la transformada de Fourier y una espacial, basada en teoría de grafos más clásica como el algoritmo de Weisfeiler-Lehman.
En este artículo hablaremos de uno de los frameworks más populares a día de hoy, el cual se basa en un proceso de envío de mensajes entre nodos (message passing) y engloba gran parte de las definiciones espectrales y espaciales. A partir de ahora lo llamaremos MPNN (del inglés Message Passing Neural Networks).
La idea principal de las MPNN es la de, en cada capa de la red neuronal, aprender nuevas características/estados/representaciones/etc para todos los nodos del grafo. Para ello, se lleva a cabo el envío de mensajes entre éstos, en el cual cada nodo (i) envía un mensaje a sus vecinos (con quién está conectado) basado en su estado actual, (ii) recibe y agrega todos mensajes de sus vecinos y (iii) actualiza su representación basándose en su estado actual y la agregación. Por lo tanto, en la capa l, cada nodo realiza el siguiente cálculo:
Figura 6: Fórmula del proceso Message Passing
La función MENSAJE(·) es básicamente una red neuronal, la cual toma como entrada la representación de un nodo y calcula un mensaje de salida. Podría ser que también tomase más información, como la representación del nodo de destino v, el peso entre el nodo u y el nodo v, etcétera.
La función AGREGAR(·), debe ser una función que no requiera un número fijo de elementos de entrada (no todos los nodos tienen el mismo número de vecinos) ni del orden (no debe importar el orden de los mensajes recibidos). Los ejemplos más claros son la suma y la media de mensajes.
Finalmente, la función ACTUALIZAR(·) es la que toma la agregación de todos los mensajes y actualiza el estado del nodo de destino. Ésta podría ir de algo tan simple como la función identidad (no hacer nada) a una función de activación o incluso otra red neuronal.
Vamos a ver un ejemplo gráfico, en el cual tenemos una sola capa de MPNN. Si nos fijamos solamente en el nodo B (tened en cuenta que esto se calcula para todos los nodos en cada capa), el resultado es el siguiente: Los vecinos de B son los nodos A, C y D. Con cada uno de ellos se calcula un mensaje a través de la función MENSAJE(·), en azul. El nodo B recibe los tres mensajes y los agrega siguiendo la estrategia marcada por AGREGAR(·), en verde. Finalmente, el nuevo estado del nodo B se calcula mediante una última función ACTUALIZAR(·), en rojo.
Figura 7: Ejemplo gráfico del proceso de Message Passing
Debemos añadir un par de notas para completar nuestra explicación de las GNNs. MENSAJE(·) normalmente es implementada mediante una red neuronal, la cual es la misma para todos los nodos, es decir, se comparten los pesos (al igual que en una CNN). Además, en este ejemplo gráfico solo se toma la representación del nodo origen como entrada para calcular el mensaje, pero tal y como he comentado previamente, también podría incluirse información del nodo destino, del link, etcétera.
Finalmente, la fórmula que hemos visto es individual para cada nodo. Realmente, según qué cómo se definan las funciones MENSAJE(·), AGREGAR(·) y ACTUALIZAR(·), este proceso puede implementarse como una multiplicación de matrices, lo cual actualiza todos los nodos de golpe e implica menos consumo de memoria y recursos computacionales.
Y todo esto… ¿Para qué sirve?
Tal y como hemos visto, una GNN aprende representaciones de los nodos de un grafo a partir de un proceso de envío de mensajes entre vecinos. La siguiente pregunta que debemos responder es, ¿cómo se utilizan esas representaciones? En imágenes, podemos llevar a cabo tareas a distintos niveles.
Por ejemplo, una clasificación de imágenes calcula un único valor para cada imagen, mientras que una segmentación requiere el cálculo de un valor para cada píxel de la imagen de entrada. Pues sucede exactamente lo mismo para el caso de grafos, podemos distinguir tres niveles principales:
1. Nivel de nodo: Para una clasificación o regresión de nodos, se aplicaría el predictor directamente a la representación de cada nodo obtenida a partir de la GNN. Un ejemplo de esta aplicación podría ser la predicción de qué usuarios de Facebook son votantes de izquierdas o de derechas.2. Nivel de link: Para poder hacer una predicción sobre la conexión entre dos nodos, la entrada del modelo debería ser la combinación de la representación de ambos (con una concatenación, suma, etc). Por ejemplo, podríamos calcular la probabilidad de que dos usuarios de Facebook se conozcan en la vida real.3. Nivel de grafo: Este caso es el menos directo, ya que es necesario obtener una única representación para todo el grafo a partir de las obtenidas para cada nodo. A esta agregación global, dentro del framework MPNN se le llama READOUT(·). Tal y como en la función AGREGAR(·) de MPNN, esta función no debe tener un número fijo de entrada (cada grafo puede tener distinto número de nodos) ni depender del orden de entrada. Un ejemplo popular es el cálculo de la probabilidad de que una molécula (modelado como un grafo donde los nodos son átomos) sea un buen medicamento.
Aplicaciones
En este apartado hablaremos de aplicaciones de las GNNs, algunas de ellas ya nos afectan directamente con funcionalidades que utilizamos durante nuestro día a día.
1.-Predicción de tráfico: DeepMind colaboró con Google Maps para crear un modelo de predicción de tráfico y tiempo estimado de trayecto basado en GNNs en ciertas ciudades del mundo. Para lograrlo, representaron el sistema de carreteras de estas ciudades como grandes grafos, siendo cada nodo un segmento de la vía con características de velocidad y longitud, y conectando aquellos nodos que representan segmentos adyacentes.
Figura 8: GNNs para la predicción de tráfico.
3. Sistemas de recomendación: Los sistemas de recomendación deben decidir qué producto o contenido puede ser de interés para un usuario. Esto se traduce a un problema a nivel de link entre un nodo del tipo usuario y uno del tipo producto. Una de las primeras compañías en usar esta tecnología a gran escala fue Pinterest para recomendar “pins” a sus usuarios. Posteriormente, empresas como Alibaba y Amazon también han incorporado modelos basados en GNNs en sus plataformas.
4. Detección de Fake News: En 2018, Michael Bronstein y algunos de sus estudiantes lanzaron la start-up Fabula.ai, la cual fue adquirida por Twitter en 2019. Su objetivo era el de detectar Fake-News en redes sociales partiendo de la hipótesis de que estas se esparcen entre los usuarios (nodos de la red) de forma distinta a las noticias reales. Como podréis imaginar, las GNNs, mediante el proceso de message-passing, son bastante buenas modelando este proceso de difusión de la información, como el que podría representar el siguiente GIF:
Figura 9: Ejemplo de difusión en un grafo o red. Simulating Network Diffusion with R
Con las aplicaciones acabamos nuestro pequeño artículo sobre Graph Neural Networks. Esperamos que os haya gustado y os haya sido útil como introducción a este maravilloso campo. Actualmente hay muchas líneas de investigación abiertas como el aprendizaje del propio grafo: ¿Qué pasa si no hay una estructura de grafo implícita en nuestros datos? ¿Podemos aprender una para utilizar GNNs?.
¡Seguro que las GNNs darán mucho de qué hablar los próximos años! ¡Saludos!
Autores: Bruno Ibáñez, Investigador de Ciberseguridad e IA en Ideas Locas y Oscar Pina doctorando en Geometric Deep Learning y Graph Representation Learning
No hay comentarios:
Publicar un comentario