“GRAPH ATTENTION NETWORKS”: it is awesome as it has attention mechanisms

Today, I would like to introduce “GRAPH ATTENTION NETWORKS” (GAT) (1). I like it very much as it has attention mechanisms. Let us see how it works.

  1. Which nodes should we pay more attention to?

As I explained before, a GNN updates each node by gathering information from its neighbors. But you may be wondering whether some of this information is more important than the rest. In other words, which nodes should we pay more attention to? As the chart shows, information from some nodes matters more than from others. In the chart, the thicker the red arrow from a sender node to the receiver node is, the more attention GAT should pay to that sender. But how can we know which nodes deserve more attention?

2. Attention mechanism

Some of you may not know about “attention mechanisms”, so let me explain them in detail. They became popular when the Natural Language Processing (NLP) model called “Transformer” introduced this mechanism in 2017. That NLP model can understand which words are more important than others when it considers one specific word in a sentence. GAT introduces the same mechanism to answer the question “which nodes should GAT pay more attention to when gathering information from neighbors?”. The chart below explains how the attention mechanism works. It is taken from the original research paper of GAT (1).

In order to understand which nodes GAT should pay more attention to, attention weights (red arrows) are needed. The bigger these weights are, the more attention GAT should pay. To calculate attention weights, the features of a sender node (green arrow) and the receiver node (blue arrow) are first linearly transformed and concatenated. Then eij is calculated by a single-layer neural network (formula 1). This is called “self-attention” and eij is called an “attention coefficient”. Once the attention coefficients of all sender nodes are obtained, we put them into the softmax function to normalize them (formula 2). Then the “attention weights” aij are obtained (right illustration). When you want to know more, please check formula 3. Finally, the receiver node can be updated based on the attention weights aij (formula 4).
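To make the steps above concrete, here is a minimal NumPy sketch of single-head attention for one receiver node. The feature sizes, weight matrices, and node features below are toy random values for illustration, not anything from the paper:

```python
import numpy as np

def leaky_relu(x, negative_slope=0.2):
    return np.where(x > 0, x, negative_slope * x)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
F_in, F_out, n_senders = 4, 3, 5        # toy sizes
W = rng.normal(size=(F_in, F_out))      # shared linear transformation
a = rng.normal(size=2 * F_out)          # weights of the single-layer attention network

h_i = rng.normal(size=F_in)                 # receiver node features
h_js = rng.normal(size=(n_senders, F_in))   # sender (neighbor) node features

Wh_i = h_i @ W                              # linear transformation
Wh_js = h_js @ W

# formula 1: e_ij = LeakyReLU(a^T [W h_i || W h_j])  (concatenate, then one layer)
concat = np.hstack([np.tile(Wh_i, (n_senders, 1)), Wh_js])
e = concat @ a
e = leaky_relu(e)                           # attention coefficients e_ij

# formula 2: normalize the coefficients over all senders with softmax
alpha = softmax(e)                          # attention weights a_ij

# formula 4: update the receiver node as the weighted sum of transformed senders
h_i_new = alpha @ Wh_js                     # shape (F_out,)
```

Note how the softmax guarantees that the attention weights of all senders are non-negative and sum to one, so the update is a weighted average of the transformed neighbor features.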

3. Multi-head attention

GAT introduces multi-head attention. It means that GAT has several independent attention mechanisms (right illustration). K attention mechanisms execute the transformation of formula 4, and then their features are concatenated (formula 5). When we perform multi-head attention on the final layer of the network, instead of concatenating, GAT averages the results from each attention head and delays applying the softmax for the classification task (formula 6).
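A small sketch of the difference between formula 5 (concatenation in hidden layers) and formula 6 (averaging in the final layer). The per-head attention weights and features below are random stand-ins for what each head would actually compute:

```python
import numpy as np

rng = np.random.default_rng(1)
K, n_senders, F_out = 3, 5, 4   # toy values: 3 heads, 5 neighbors, 4 features per head

# pretend each head already produced normalized attention weights (rows sum to 1)
# and its own linearly transformed neighbor features
alpha = rng.dirichlet(np.ones(n_senders), size=K)   # shape (K, n_senders)
Wh = rng.normal(size=(K, n_senders, F_out))

# each head runs the update of formula 4 independently
per_head = np.einsum('kn,knf->kf', alpha, Wh)       # shape (K, F_out)

h_concat = per_head.reshape(-1)   # formula 5: hidden layers concatenate -> K*F_out dims
h_avg = per_head.mean(axis=0)     # formula 6: the final layer averages  -> F_out dims
```

Concatenation makes the hidden representation K times wider, while averaging on the final layer keeps the output at the size expected by the classifier.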

Hope you enjoyed the article. I like GAT as it is easy to use and more accurate than the other GNNs I explained before. I will update my articles soon. Stay tuned!

(1) Graph Attention Networks, Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, Yoshua Bengio, ICLR 2018

Notice: ToshiStats Co., Ltd. and I do not accept any responsibility or liability for loss or damage occasioned to any person or property through using materials, instructions, methods, algorithms or ideas contained herein, or acting or refraining from acting as a result of such use. ToshiStats Co., Ltd. and I expressly disclaim all implied warranties, including merchantability or fitness for any particular purpose. There will be no duty on ToshiStats Co., Ltd. and me to correct any errors or defects in the codes and the software.

Node Classification with Graph Neural Networks by PyG!

Last time, I introduced GCN (“GRAPH CONVOLUTIONAL NETWORKS”) in theory. Today, I would like to solve a real problem with GCN. Before doing that, let me choose the best framework for graph neural networks: “PyG” (1).

  1. PyG (PyTorch Geometric)

Let us look at the explanation of PyG in its official documentation.

“PyG is a library built upon PyTorch to easily write and train Graph Neural Networks for a wide range of applications related to structured data. PyG is both friendly to machine learning researchers and first-time users of machine learning toolkits.”

I think it is the best choice for beginners because

  • It is based on PyTorch, which has a Python interface and is widely used in deep learning tasks.
  • There are well-written documents and notebook tutorials.
  • It has many “out of the box” functions so we can start experiments with GNN immediately.

2. Prepare graph data

Our task is called “node classification”. It means that each node has its own class (e.g. age, rating, income, fail or success, default or not, purchase or no purchase, cured or not cured, whatever you like). We would like to predict the class of each node based on the graph data.

Let me introduce the “Cora” dataset (2), a citation network. Each node represents a document; it has a 1433-dimensional feature vector and belongs to one of seven classes. Each edge represents a citation from one document to another. Our task is to predict the class of each document. Let us visualize each node as a dot below before training our GCN. We can see seven colours, as there are seven classes in this graph.

3. GCN implementation with PyG

Let us train a GCN to analyse the graph data now. I explained how GCN works before, so if you missed it, please check it out.

This is a GCN implementation with PyG. PyG has GCN built in, so all we have to do is 1. import GCNConv, and 2. create a class using GCNConv. That’s it! It looks easy if you are already familiar with PyTorch. When you want to run the whole notebook, it is available in the PyG official documentation. Here is the link.
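As a reminder of what GCNConv computes under the hood, here is a minimal from-scratch NumPy sketch of one GCN layer (symmetric normalization with self-loops). The toy graph, one-hot features, and random weights are hypothetical, not taken from the Cora dataset:

```python
import numpy as np

# toy citation graph: 4 documents, undirected edges as an adjacency matrix
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 1],
              [1, 0, 0, 0],
              [0, 1, 0, 0]], dtype=float)
X = np.eye(4)                       # one-hot node features (toy stand-in)

A_hat = A + np.eye(4)               # add self-loops so a node keeps its own features
D_inv_sqrt = np.diag(1.0 / np.sqrt(A_hat.sum(axis=1)))
norm = D_inv_sqrt @ A_hat @ D_inv_sqrt   # symmetrically normalized adjacency

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 2))         # learnable weight matrix (random here)

H = np.maximum(norm @ X @ W, 0)     # one GCN layer with ReLU, shape (4, 2)
```

PyG's GCNConv performs this same propagation with sparse edge indices and trained weights, which is why the class you write on top of it stays so short.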

Here is the result of our analysis. It looks good as nodes are classified into seven classes.

GCN can be applied to many tasks as it has a simple structure. Why don’t you try it by yourself with PyG today?

(1) PyG (PyTorch Geometric)  https://www.pyg.org/

(2) Revisiting Semi-Supervised Learning with Graph Embeddings, Zhilin Yang, William W. Cohen, Ruslan Salakhutdinov, May 2016