
The flow of generative networks

Writer: Om Kumar

Recently, I have been trying to generate drug samples using generative flow architectures, and by now I have gotten accustomed to my prof. nodding in utter disappointment at the samples that my models generate. Which makes me wonder: when models like DALL-E and Stable Diffusion can generate such a wide variety of images, what is the bottleneck in generating chemical molecules from a given sample of similar drugs?


But what are "Generative Flow Networks" a.k.a. GFlowNets?


Generative Flow Networks are quite a recent phenomenon, inspired by Reinforcement Learning and Deep Learning, and personally it was quite a task for me to grasp the concept behind them. Primarily because GFlowNets are not just an independent concept, but rather a mixture of a host of different Machine Learning concepts, one of them being Graph Neural Networks. So I feel it would suffice to kinda delve into Graph Neural Networks here and maybe cover GFlowNets in another post, we'll see.


Graph Neural Networks:

It would be a very good idea to start with what a graph actually is. To most of us they might seem abstract, but graphs pop up almost everywhere you can find entities (nodes) that are related among themselves (depicted by edges) via some pre-defined notion.


Plain graphs and networks can only represent relations to a certain extent, and we can additionally specialize them via the concept of directed and undirected edges. We have all read about and seen graph data in the context of social networks or citation networks (scientists citing each other), but there are a few other interesting places where graphs tend to yield insightful patterns and ideas.


Take images, for instance. Using graphs to represent images sounds totally absurd and useless at first, because images already have a very nice regular structure, which is one reason why they are stored as 2D or 3D arrays (one band per color channel). But picture this: what if we were to represent every pixel in the image as a node, with the adjacent pixels (maybe even those of a different color channel) forming its neighbors, connected via appropriate edges? Although very redundant, this actually paints a very nice picture if you think about it in a particular way. There is a representation for graphs commonly known as the adjacency matrix, which is a way of recording the nodes of a graph and their connections.


So say there are 25 pixels in a 5 x 5 image. We order the pixels, give each one a row and a column in a 25 x 25 matrix, and fill in the entries so that they mark the edges shared between two nodes, i.e. pixels in the case of an image.
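To make the pixel-graph idea concrete, here is a minimal sketch in plain Python. It assumes a toy 5 x 5 image and assumes that pixels count as neighbours when they touch, diagonals included; the names `pixel_index` and `adjacency` are made up for this illustration. With 25 pixels as nodes, the adjacency matrix gets one row and one column per pixel, so it comes out as 25 x 25.

```python
# Adjacency matrix of a toy 5 x 5 image: every pixel is a node and
# pixels that touch (including diagonally) share an edge.
HEIGHT, WIDTH = 5, 5          # assumed toy image size
n_nodes = HEIGHT * WIDTH      # 25 pixels -> 25 nodes


def pixel_index(row, col):
    """Flatten a (row, col) pixel position into a single node index."""
    return row * WIDTH + col


adjacency = [[0] * n_nodes for _ in range(n_nodes)]
for row in range(HEIGHT):
    for col in range(WIDTH):
        for d_row in (-1, 0, 1):
            for d_col in (-1, 0, 1):
                if d_row == 0 and d_col == 0:
                    continue  # a pixel is not treated as its own neighbour
                n_row, n_col = row + d_row, col + d_col
                if 0 <= n_row < HEIGHT and 0 <= n_col < WIDTH:
                    adjacency[pixel_index(row, col)][pixel_index(n_row, n_col)] = 1

print(len(adjacency), "x", len(adjacency[0]))   # 25 x 25
print(sum(adjacency[pixel_index(0, 0)]))        # corner pixel: 3 neighbours
print(sum(adjacency[pixel_index(2, 2)]))        # centre pixel: 8 neighbours
```

Most of the 625 entries stay at 0, which already hints at why this representation is wasteful for larger graphs.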


Now it doesn't matter whether you love math or graphs or not: if you have a soul, you have to appreciate the underlying patterns that pop up here. But there is only so much beauty one can appreciate, because when the question of efficiency comes up, this is clearly not the best choice out there.


Going beyond Beauty:

In the wild though, graphs find application not just in the most beautiful of domains, but also in the useful ones. Take heterogeneous data like molecules: they are the building blocks of matter, with atoms and electrons hanging in 3D space, joined to their brethren via bonds, and of different kinds at that - single/double, covalent/ionic. It's a very convenient and common abstraction, then, to describe molecules as graphs! With nodes representing atoms and edges representing covalent bonds.
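As a small illustration of that abstraction, here is a sketch of formaldehyde (H2C=O) written as a graph in plain Python. The data layout (`atoms`, `bonds`) and the helper `neighbours_of` are assumptions made up for this example, not a standard chemistry format.

```python
# Formaldehyde (H2C=O) as a graph: atoms are nodes, bonds are edges.
atoms = ["C", "O", "H", "H"]   # node i is atoms[i]
bonds = [
    (0, 1, "double"),          # C=O
    (0, 2, "single"),          # C-H
    (0, 3, "single"),          # C-H
]


def neighbours_of(atom_index):
    """List (element, bond type) for every atom bonded to `atom_index`."""
    result = []
    for a, b, bond_type in bonds:
        if a == atom_index:
            result.append((atoms[b], bond_type))
        elif b == atom_index:
            result.append((atoms[a], bond_type))
    return result


print(neighbours_of(0))  # [('O', 'double'), ('H', 'single'), ('H', 'single')]
print(neighbours_of(1))  # [('C', 'double')]
```

The bond type riding along on each edge is the kind of extra edge information that comes up again later as an edge embedding.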

Graphs can help us make sense of even the most cluttered data, as in the case of social networks, where they can be very insightful in figuring out patterns and the collective behavior of people and other inter-connected entities.


Confused?

Alright so we have discussed a few use cases where we can use graphs to represent the data but what after that? What tasks can we perform on these collections once they are represented as graphs and how?

So let's first take a look at the tasks that we can perform on the graphs once we have represented the required data using them, and then dive into how GNNs can make the task simpler.


Graph Level Task

The first of the many tasks that can be performed are "graph level tasks", i.e. looking at the graph as a whole and then predicting a property of the entire graph. For instance, predicting what a molecule smells like based on its graph representation is a graph level task. To draw a rough analogy, a graph level task is akin to classifying whole images from a dataset like CIFAR-10, while with text a similar problem is sentiment analysis, where we want to identify the mood or the context of a given sentence.


Node Level Task

A node level task is about predicting the properties of each node in a graph. For instance, you can imagine a small social circle as a graph; a node level task there would be to classify each node (person) as belonging to a certain class.


Edge Level Task

An edge level task deals with the relationships between nodes. One example is classifying the relationships between the various objects present in an image. Consider an image that depicts a pitcher on a baseball field. If we consider all the entities present in the image as nodes, and represent the connections between them as edges of a graph, then one of the many relations these edges can depict is the "action" between any two objects. For instance, the pitcher and the hitter can be connected as "playing".


A suitable representation

But to perform all of these awesome tasks, we would need a representation for these graphs that would let us work with them in a mathematical setting where it's more about numbers than pictures. 'Cause even though the enamel of your teeth is harder than steel, we don't use it in construction(Ha Ha bad joke)


Alright, so taking a step back, we are looking for a mathematical or computer-science-ish representation of graphs. So we think about what information about a graph we need to capture: the nodes, the edges, and which nodes are connected by which edges.

One possibility is the good old adjacency matrix, so consider an example below:


In the matrix, an entry of 1 indicates a connection between the corresponding nodes in the graph. Now although this is a very nice representation, it's not very memory efficient: the matrix needs one entry for every pair of nodes, even though most pairs are usually not connected. Add to that the fact that if the order of the nodes is switched, the matrix changes even though the graph is the same, so it would be like playing dice while using these matrices as inputs to a model.
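A quick sketch of that ordering problem, assuming a small four-node graph in which nodes 1, 2 and 3 are each connected to node 4. The helper `adjacency_matrix` is a made-up name for this illustration.

```python
def adjacency_matrix(edges, node_order):
    """Return the 0/1 adjacency matrix with rows/columns in `node_order`."""
    position = {node: i for i, node in enumerate(node_order)}
    matrix = [[0] * len(node_order) for _ in node_order]
    for a, b in edges:
        matrix[position[a]][position[b]] = 1
        matrix[position[b]][position[a]] = 1  # undirected edge
    return matrix


edges = [(1, 4), (2, 4), (3, 4)]
original = adjacency_matrix(edges, node_order=[1, 2, 3, 4])
shuffled = adjacency_matrix(edges, node_order=[4, 1, 2, 3])

print(original)              # [[0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1], [1, 1, 1, 0]]
print(shuffled)              # [[0, 1, 1, 1], [1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]]
print(original == shuffled)  # False: same graph, different matrix
```

Both matrices describe exactly the same graph, yet a model fed the raw matrix would see two different inputs.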


Adjacency Lists

It's almost as if the adjacency matrix has a sibling to help us out, and this time the memory usage is very efficient. Here we use a tuple to capture which nodes are connected. For instance, in the previous graph the nodes 2 and 4 are connected, so the adjacency list contains the tuple (2,4), and the whole graph can simply be represented as:

[[1,4],[2,4],[3,4]]. Now notice how even if we change the order, it doesn't make a difference, i.e. the list [[2,4],[1,4],[3,4]] represents the same information as the former. This nice property is called "Permutation Invariance".
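Here is a tiny sketch of what "represents the same information" means in code. Python compares lists position by position, so the two orderings are not equal as lists; comparing them as sets of undirected edges shows they describe the same graph. The helper name `canonical` is made up for this example.

```python
def canonical(edge_list):
    """Order-independent view of an adjacency list of undirected edges."""
    return {frozenset(edge) for edge in edge_list}


first = [[1, 4], [2, 4], [3, 4]]
second = [[2, 4], [1, 4], [3, 4]]

print(first == second)                        # False: the raw lists differ
print(canonical(first) == canonical(second))  # True: same set of edges
```

So the invariance lives in the information, not in the Python list itself; anything that consumes the list has to treat it as an unordered collection.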

Talking about a graph level task, we need efficient transfer of information between the nodes of the graph to be able to make complex predictions.

This transfer of information, called message passing, is the primary methodology behind all of the things that we plan to accomplish using a Graph Neural Network.

Information from each node/edge is collected from its neighbors and aggregated using some function, and the result is then used to update the information stored across the graph.


Graph Neural Networks

So let's talk about Graph Neural Networks then, with all the information about the graph loaded into the adjacency list (the adjacency list is preferable, as the order of its entries doesn't change the information expressed). We will be discussing the simplest GNN architecture, which uses message passing to do the required tasks.


But what in the wild world is message passing!????

Okay, I will break it down a bit. Imagine that you treat a molecule's structure as a graph. The information, or essentially the numbers, stored in the nodes could tell you something about the atoms, those in the edges could tell you something about the bonds, and taken together as an aggregate they convey something about the whole graph, i.e. the molecule. So the idea of message passing is for each node in the graph to kinda collect these numbers (its own and those of its neighbors), aggregate them (via some function, maybe mean or sum) and then update its own numbers with the result, so that its neighbors can pick them up in the next round.


But this turns out to be an issue for nodes/edges that are far apart because messages take longer to transmit even though we apply message passing multiple times across those nodes.
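A minimal sketch of both points, assuming a path graph of five nodes (0-1-2-3-4), a single number per node, and sum as the aggregation function. The function name and the toy values are made up for this illustration; a real GNN would also apply a learned function after aggregating.

```python
def message_passing_round(features, neighbours):
    """Each node sums its own value with the values of its neighbours."""
    return [
        features[node] + sum(features[n] for n in neighbours[node])
        for node in range(len(features))
    ]


# Path graph 0 - 1 - 2 - 3 - 4; only node 0 starts with any signal.
neighbours = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
features = [1.0, 0.0, 0.0, 0.0, 0.0]

rounds_until_reached = None
for round_number in range(1, 6):
    features = message_passing_round(features, neighbours)
    if rounds_until_reached is None and features[4] != 0.0:
        rounds_until_reached = round_number

print(rounds_until_reached)  # 4: one round per hop between node 0 and node 4
```

Each round only moves information one hop, so two nodes that are k edges apart need at least k rounds before they hear from each other.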


Solution to the problem of message transfer

One of the possible ways in which we could tackle the issue of message passing between far away nodes is by using a global representation of the graph, something called the context vector.


The global context vector is connected to all the nodes and edges of the network and can act as a communicator between them. You can think of it as the internet that connects two places far apart on the globe: it allows seamless exchange of information between two distant parts of the graph, which makes the representation more complete and more connected in a sense.
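One way to picture this, as a rough sketch: assume the global context is simply the mean of all node values, and assume every node adds it in during its update. Both choices are illustrative assumptions (as are the names below), not the only way to build a context vector.

```python
def round_with_global_context(features, neighbours):
    """One message-passing round that also mixes in a graph-wide summary."""
    context = sum(features) / len(features)  # global context: mean over nodes
    updated = [
        features[node] + sum(features[n] for n in neighbours[node]) + context
        for node in range(len(features))
    ]
    return updated, context


# Path graph 0 - 1 - 2 - 3 - 4; only node 0 starts with any signal.
neighbours = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
features = [1.0, 0.0, 0.0, 0.0, 0.0]

updated, context = round_with_global_context(features, neighbours)
print(updated[4] > 0.0)  # True: node 4 hears about node 0 after one round
```

With the shared context in the loop, the signal from node 0 reaches node 4 in a single round, even though the two nodes are four edges apart.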


Then the simple magic of a Graph Neural Network is that this representation is passed through the network to learn the required representation. So the things learned could very likely be "what numbers in the nodes/edges allow me to represent a paracetamol molecule" etc.

An important fact to keep in mind here is that the network doesn't change the number of nodes or edges in the graph. So the amount of information required to represent the output graph is the same as that required to represent the input graph; the only difference is that the embeddings have been updated.


Now this seems very simple, but sometimes information is missing from either the nodes or the edges, and in that case, if we are supposed to apply a neural network on the node embeddings, we need to "pool" the data. What basically happens is that every node gathers information from its surrounding neighbors and edges and aggregates it via a process called pooling (yes, you got it right, message passing), and then the neural network or the function is applied on the aggregated data.
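A small sketch of pooling, assuming the information lives only on the edges and the nodes start out empty. Each node sums the embeddings of the edges that touch it. The function name and the numbers are made up for illustration.

```python
def pool_edges_to_nodes(edge_features, n_nodes, dim):
    """Sum the embeddings of all edges touching a node into that node."""
    pooled = [[0.0] * dim for _ in range(n_nodes)]
    for (a, b), embedding in edge_features.items():
        for node in (a, b):
            pooled[node] = [p + e for p, e in zip(pooled[node], embedding)]
    return pooled


# Star graph: nodes 0, 1 and 2 are each connected to node 3.
edge_features = {
    (0, 3): [1.0, 0.0],
    (1, 3): [0.0, 2.0],
    (2, 3): [0.5, 0.5],
}

node_features = pool_edges_to_nodes(edge_features, n_nodes=4, dim=2)
print(node_features[0])  # [1.0, 0.0]: node 0 touches a single edge
print(node_features[3])  # [1.5, 2.5]: node 3 touches all three edges
```

After pooling, every node has an embedding of its own, and the usual per-node neural network can be applied to it.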


Now whether we choose to transfer data from nodes to edges or vice versa is something that needs to be looked into, because the two embeddings need not be of the same size, so it is not obvious how to combine them directly. We can again use a neural network to map from one embedding space to the other, or maybe concatenate the embeddings together (don't worry, think of an embedding as simply a vector, or even simpler, a collection of numbers that represents something about the associated entity). What's also important is the way in which information is updated. Remember how we talked about updating the embeddings of the edges and the nodes? The order in which these updates are done matters too. These design decisions, among others (number of nodes, degree of each node etc.), are a bunch of factors that go into making efficient Graph Neural Networks.
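The two options for mismatched sizes can be sketched like this, assuming a node embedding of size 3 and an edge embedding of size 2. The weight matrix here is hand-picked for illustration; in a real network it would be learned.

```python
def linear_map(vector, weights):
    """Multiply `vector` (length m) by `weights` (m rows, n columns)."""
    n_out = len(weights[0])
    return [
        sum(vector[i] * weights[i][j] for i in range(len(vector)))
        for j in range(n_out)
    ]


node_embedding = [0.5, 0.4, 0.9]   # size 3
edge_embedding = [1.0, 2.0]        # size 2

# Option 1: project the edge embedding into node space, then add.
edge_to_node = [
    [1.0, 0.0, 0.5],
    [0.0, 1.0, 0.5],
]                                   # made-up weights, normally learned
projected = linear_map(edge_embedding, edge_to_node)
combined_by_sum = [n + p for n, p in zip(node_embedding, projected)]

# Option 2: skip the projection and concatenate the two vectors.
combined_by_concat = node_embedding + edge_embedding

print(len(combined_by_sum))     # 3
print(len(combined_by_concat))  # 5
```

Projection keeps the node embedding at its original size, while concatenation makes it grow, so whatever layer comes next has to expect the larger input.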


But wait Om, you haven't yet touched upon the fact of how predictions are made and what on earth are these Graph Neural Networks doing with this message passing yada yada yada....


Alright, so let's consider the following task: we have a bunch of drug molecules and their effectiveness listed against a disease (umm... say tuberculosis). What we can try to do is build a model that predicts the effectiveness of newer molecules against tuberculosis. In order to leverage the power of GNNs, we would first need to express our molecules as graphs and then pass them through the network.


So consider the following molecule:

Source: Me 🤡


The way we go about representing this as a graph can vary: we can either consider individual atoms as nodes, or some smaller fragments of the molecule as nodes, in which case we would need to specify the indices of the positions that other fragments can connect to. So let's say we take a CO fragment from the molecule: the vector representing the positions where you can connect can look like [0,0] (since you can attach two more atoms to the carbonyl carbon), and you can maybe have a feature vector that represents some additional information about the particular fragment. Need an example for that too??? Well C'mon >_<... okay okay I will give you one...


Say you choose to represent features like toxicity, pH, polarity etc. for every fragment/atom/node you have. So you create a vector that looks something like [0.5,0.4,0.9,.....], with the numbers in the vector representing those particular features of the fragment. This is the embedding... that I have been screaming about throughout this post. You can maybe have a similar feature vector for the bonds as well. Then comes the pooling operation, where you do crazy shit (nah, you just gather all the vectors together and aggregate them with either a sum or a mean function). Then pass the result through the neural network to learn the relation between the fragments and how they affect the molecule's effectiveness against the particular disease. So in the end, the output of our model is a single number, which could be the MIC (minimum inhibitory concentration, a measure of how well the molecule inhibits a bacterium) or some other metric.
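Here is a bare-bones sketch of that pipeline in plain Python: per-fragment feature vectors, a sum pool into one graph-level vector, and a single linear layer that spits out one number. All the values, the weights and the function names are made up for illustration; a real model would use learned weights and several layers.

```python
def sum_pool(node_embeddings):
    """Collapse a list of per-node vectors into one graph-level vector."""
    return [sum(column) for column in zip(*node_embeddings)]


def predict(node_embeddings, weights, bias):
    """Graph-level prediction: pool the nodes, then apply a linear layer."""
    graph_vector = sum_pool(node_embeddings)
    return sum(w * x for w, x in zip(weights, graph_vector)) + bias


# Three fragments, each with made-up (toxicity, pH, polarity) style features.
molecule = [
    [0.5, 0.4, 0.9],
    [0.1, 0.7, 0.2],
    [0.3, 0.3, 0.3],
]
weights = [1.0, -1.0, 0.5]   # hypothetical weights, normally learned
bias = 0.1

score = predict(molecule, weights, bias)
print(round(score, 3))  # a single number for the whole molecule
```

Because the pool is a sum, the prediction does not depend on the order in which the fragments are listed, which is the same permutation invariance that made the adjacency list attractive.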


After which you have the good old backpropagation to learn the feature vectors. Now of course there are other tricks up our sleeves which we can leverage to make these predictions better, like utilizing information about the entire graph and its connections rather than just its nodes and edges; what type of graph we choose to represent the data also matters. But as we have seen, graphs in general can get complicated sometimes, and life, I feel, is nothing but a graph, and it's best traversed when we do it one node at a time (❁´◡`❁)



PS: If you find any errors in the above writing, I beg you to shatter my delusions at okhere21@gmail.com



 
 
 


Information is everywhere, but its meaning is created by the observer who interprets it.

                          - Naval Ravikant
