Graph Neural Networks (GNN)

Neural networks that operate on graph-structured data. Nodes and edges carry features; each layer updates a node's representation by aggregating messages from its neighbours. Used for molecules, social networks, knowledge graphs, and any data with non-Euclidean structure.

TL;DR. A GNN is a stack of message-passing layers. In each layer, every node aggregates features from its neighbours (sum / mean / attention), combines the result with its own features, and passes it through an MLP. GCN uses a normalised adjacency for the aggregation; GAT learns attention weights over neighbours. For graph-level tasks (molecule property prediction), pool node features at the end. USAAIO and IOAI use GNNs on Cora-style node classification, OGB molecular benchmarks, and any tabular dataset with a natural graph structure (e.g. social or citation networks).

1. The intuition

A CNN exploits a grid: every pixel has a fixed neighbourhood (e.g. a 3x3 window) and weights are shared across positions. A GNN generalises this to arbitrary graphs: each node has a variable-size neighbourhood, and the same learned function is applied at every node.

One GNN layer propagates information one hop. After L layers, each node's representation can depend on every node within L hops, exactly analogous to receptive-field growth in a CNN. Because a node's neighbours have no inherent ordering, the aggregation must be permutation-invariant: sum, mean, max, or attention, never concatenation in a fixed order.
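To make the permutation-invariance requirement concrete, here is a minimal NumPy sketch (the reference code later uses PyTorch; NumPy is used here only to keep the example dependency-light). The neighbour features and the reordering are arbitrary illustrative values: sum, mean, and max give the same result under any reordering of a node's neighbours, while fixed-order concatenation does not.

```python
import numpy as np

rng = np.random.default_rng(0)
neigh = rng.normal(size=(4, 3))     # features of one node's 4 neighbours (illustrative)
perm = np.array([2, 0, 3, 1])       # a fixed, non-identity reordering of the same neighbours

def aggregate(msgs, how):
    """Permutation-invariant aggregators over the neighbour axis (axis 0)."""
    if how == "sum":
        return msgs.sum(axis=0)
    if how == "mean":
        return msgs.mean(axis=0)
    if how == "max":
        return msgs.max(axis=0)
    raise ValueError(how)

for how in ("sum", "mean", "max"):
    same = np.allclose(aggregate(neigh, how), aggregate(neigh[perm], how))
    print(how, same)                # invariant: same result in any order

# Concatenating neighbours in list order is NOT invariant: reordering changes the vector.
concat = neigh.reshape(-1)
concat_perm = neigh[perm].reshape(-1)
print("concat", np.allclose(concat, concat_perm))
```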

Two failure modes shape the design: over-smoothing (with too many layers, all node representations converge to nearly the same value) and over-squashing (information from a rapidly growing set of distant nodes is compressed into fixed-size vectors as it passes through narrow bottlenecks). Practical GNNs are therefore typically 2-5 layers deep, often with residual connections.
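A minimal sketch of over-smoothing, assuming a purely linear GCN (no weight matrices, no ReLU) on a small hypothetical graph. Repeated multiplication by A_tilde pushes every node's representation toward the same direction, so the average pairwise cosine similarity between nodes climbs toward 1 as depth grows. Learned weights and nonlinearities change the details, but this averaging pressure is what deep GNNs have to fight.

```python
import numpy as np

def normalised_adjacency(A):
    """D_hat^(-1/2) (A + I) D_hat^(-1/2), dense NumPy version."""
    A_hat = A + np.eye(A.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
    return d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]

def mean_cosine(H):
    """Average pairwise cosine similarity between node representations."""
    Hn = H / np.linalg.norm(H, axis=1, keepdims=True)
    S = Hn @ Hn.T
    n = H.shape[0]
    return (S.sum() - n) / (n * (n - 1))   # exclude the diagonal (self-similarity = 1)

# Hypothetical connected graph: a 6-cycle with one chord (0-3).
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0), (0, 3)]
A = np.zeros((6, 6))
for u, v in edges:
    A[u, v] = A[v, u] = 1.0
A_tilde = normalised_adjacency(A)

rng = np.random.default_rng(0)
H0 = rng.normal(size=(6, 4))
H = H0.copy()
for depth in range(1, 33):
    H = A_tilde @ H                        # linear GCN propagation, weights and ReLU omitted
    if depth in (1, 2, 4, 8, 16, 32):
        print(depth, round(mean_cosine(H), 4))
```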

2. The math

Message passing — the general form

Let h_v^(l) be node v's feature at layer l, and N(v) its neighbour set. A message-passing layer is:

m_v^(l+1) = AGG_{u in N(v)} ( MSG( h_v^(l), h_u^(l), e_{uv} ) )
h_v^(l+1) = UPDATE( h_v^(l), m_v^(l+1) )

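where AGG is permutation-invariant (sum / mean / max / attention), MSG is typically a small MLP (or a single linear map), e_{uv} is an optional edge feature, and UPDATE combines the node's old state with the aggregated message, usually through a residual connection. Every common GNN (GCN, GAT, GraphSAGE, GIN) is an instance of this template.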
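The template can be sketched as a single NumPy function over an edge list. The concrete choices here are illustrative assumptions, not any library's API: MSG is a one-layer ReLU map on [h_v ; h_u], AGG is a scatter-add (optionally divided by in-degree for mean), UPDATE is a tanh layer with a residual, and edge features e_{uv} are omitted.

```python
import numpy as np

def message_passing_layer(h, edges, W_msg, W_upd, agg="sum"):
    """One generic message-passing step (illustrative sketch).

    h      : (N, d) node features h_v^(l)
    edges  : list of directed edges (u, v), meaning u sends a message to v
    W_msg  : (2d, d) weight of a linear+ReLU MSG on [h_v ; h_u]
    W_upd  : (2d, d) weight of UPDATE on [h_v ; m_v]
    """
    N, d = h.shape
    src = np.array([u for u, _ in edges])
    dst = np.array([v for _, v in edges])
    # MSG(h_v, h_u): one message per edge, from receiver and sender features.
    msgs = np.maximum(0.0, np.concatenate([h[dst], h[src]], axis=1) @ W_msg)
    # AGG: scatter-add each message into its receiving node (order does not matter).
    m = np.zeros((N, d))
    np.add.at(m, dst, msgs)
    if agg == "mean":
        deg = np.bincount(dst, minlength=N).reshape(-1, 1)
        m = m / np.maximum(deg, 1)
    # UPDATE: combine own state with the aggregated message, plus a residual.
    return h + np.tanh(np.concatenate([h, m], axis=1) @ W_upd)

rng = np.random.default_rng(0)
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]   # undirected path 0-1-2, both directions listed
h = rng.normal(size=(3, 4))
W_msg = rng.normal(size=(8, 4)) * 0.1
W_upd = rng.normal(size=(8, 4)) * 0.1
out = message_passing_layer(h, edges, W_msg, W_upd, agg="mean")
print(out.shape)
```

Reversing the edge list leaves the output unchanged (up to floating-point rounding), which is exactly the permutation invariance required of AGG.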

GCN — Kipf & Welling

Given the adjacency matrix A, add self-loops to get A_hat = A + I, and let D_hat be the diagonal degree matrix of A_hat (D_hat_vv = sum_u A_hat_vu). The symmetrically normalised adjacency is:

A_tilde = D_hat^(-1/2) * A_hat * D_hat^(-1/2)
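As a worked example, take a hypothetical 3-node path graph 0 - 1 - 2. Adding self-loops gives D_hat = diag(2, 3, 2), and each entry of A_tilde is A_hat_vu / sqrt(deg_v * deg_u); for instance A_tilde_01 = 1/sqrt(2 * 3) = 1/sqrt(6), about 0.408. The NumPy sketch below computes the same matrix.

```python
import numpy as np

A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)   # path graph 0 - 1 - 2
A_hat = A + np.eye(3)                   # add self-loops
deg_hat = A_hat.sum(axis=1)             # degrees including the self-loop: [2, 3, 2]
D_inv_sqrt = np.diag(deg_hat ** -0.5)
A_tilde = D_inv_sqrt @ A_hat @ D_inv_sqrt
# Entry (v, u) = A_hat[v, u] / sqrt(deg_hat[v] * deg_hat[u]):
# diagonal 1/2, 1/3, 1/2; off-diagonal neighbours 1/sqrt(6); nodes 0 and 2 stay 0.
print(np.round(A_tilde, 4))
```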

A GCN layer in matrix form is then:

H^(l+1) = sigma( A_tilde * H^(l) * W^(l) )

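where H^(l) in R^(N x d_l) stacks all node features, W^(l) in R^(d_l x d_(l+1)) is a learned weight matrix, and sigma is a nonlinearity, typically ReLU (omitted on the last layer, which outputs logits). The normalisation D_hat^(-1/2) A_hat D_hat^(-1/2) is critical: it keeps high-degree nodes from dominating, and because every eigenvalue of A_tilde lies in (-1, 1], the propagation step itself cannot amplify activations as layers are stacked.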
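To see why the normalisation keeps activations bounded, the sketch below compares the spectral radius of the raw A_hat and of A_tilde on a hypothetical star graph (one hub, 10 leaves), then applies five linear propagation steps with weights and ReLU omitted. For a star with k leaves, A_hat has largest eigenvalue 1 + sqrt(k), so feature norms can grow by roughly that factor per layer, whereas A_tilde's largest eigenvalue magnitude is 1.

```python
import numpy as np

# Hypothetical star graph: hub node 0 joined to 10 leaves (nodes 1..10).
N = 11
A = np.zeros((N, N))
A[0, 1:] = 1.0
A[1:, 0] = 1.0
A_hat = A + np.eye(N)
d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
A_tilde = d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]

# Both matrices are symmetric, so eigvalsh applies; the largest |eigenvalue|
# bounds how much one propagation step can scale a feature vector's norm.
rho_raw = np.abs(np.linalg.eigvalsh(A_hat)).max()      # 1 + sqrt(10), about 4.16
rho_norm = np.abs(np.linalg.eigvalsh(A_tilde)).max()   # 1 (up to rounding)

# Five linear propagation steps on an all-ones input.
x = np.ones(N)
x_raw = np.linalg.matrix_power(A_hat, 5) @ x
x_norm = np.linalg.matrix_power(A_tilde, 5) @ x
print(rho_raw, rho_norm)
print(np.linalg.norm(x), np.linalg.norm(x_raw), np.linalg.norm(x_norm))
```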

GAT — graph attention

Replace the fixed A_tilde weights with learned attention coefficients. For each edge (u, v):

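e_{vu} = LeakyReLU( a^T [W h_v ; W h_u] )
alpha_{vu} = softmax_u( e_{vu} ) = exp(e_{vu}) / sum_{k in N(v) ∪ {v}} exp(e_{vk}),   for u in N(v) ∪ {v}
h_v' = sigma( sum_{u in N(v) ∪ {v}} alpha_{vu} * W h_u )

where [ · ; · ] denotes concatenation and a is a learned attention vector.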

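Multi-head GAT runs K independent attention heads and concatenates their outputs (or averages them on the last layer). The structure echoes Transformer self-attention restricted to graph neighbours: it is equivalent to attention with a mask given by the adjacency (plus self-loops), although GAT scores each pair with the additive form above rather than a scaled dot product.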
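A dense GAT forward pass can be sketched in NumPy as below. It is an untrained illustration with random parameters, not a reference implementation; the shapes chosen for W (one projection per head) and a (one attention vector per head) are assumptions of this sketch. It uses the fact that a^T [W h_v ; W h_u] splits into a_1^T W h_v + a_2^T W h_u, and implements the neighbourhood restriction as a mask over a full attention matrix. The nonlinearity sigma is left to the caller.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_layer(H, A, W, a, concat=True):
    """Dense multi-head GAT layer (sketch; sigma applied by the caller).

    H : (N, d_in) node features
    A : (N, N) 0/1 adjacency without self-loops
    W : (K, d_in, d_out) one projection per head
    a : (K, 2 * d_out) one attention vector per head
    """
    N = H.shape[0]
    mask = (A + np.eye(N)) > 0                  # attend over N(v) ∪ {v}
    outs = []
    for Wk, ak in zip(W, a):
        Z = H @ Wk                              # (N, d_out): W h for every node
        d_out = Z.shape[1]
        # e[v, u] = LeakyReLU(a_1 . W h_v + a_2 . W h_u)
        e = leaky_relu((Z @ ak[:d_out])[:, None] + (Z @ ak[d_out:])[None, :])
        e = np.where(mask, e, -np.inf)          # non-neighbours get zero weight
        e = e - e.max(axis=1, keepdims=True)    # numerically stable softmax per row
        alpha = np.exp(e)
        alpha = alpha / alpha.sum(axis=1, keepdims=True)
        outs.append(alpha @ Z)                  # sum_u alpha_vu W h_u
    return np.concatenate(outs, axis=1) if concat else np.mean(outs, axis=0)

rng = np.random.default_rng(0)
A = np.zeros((4, 4))
for u, v in [(0, 1), (1, 2), (2, 3)]:           # path graph 0-1-2-3
    A[u, v] = A[v, u] = 1.0
H = rng.normal(size=(4, 5))
W = rng.normal(size=(2, 5, 3))                  # K = 2 heads, d_out = 3
a = rng.normal(size=(2, 6))
out = gat_layer(H, A, W, a)                     # heads concatenated: shape (4, 6)
print(out.shape)
```

Node 0 attends only to itself and node 1, so perturbing node 3's features leaves node 0's output unchanged; that is the adjacency mask at work.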

Readout for graph-level tasks

For node-level tasks (citation classification) the per-node output is used directly. For graph-level tasks (molecule property prediction) all node features are pooled into a single graph embedding:

h_G = READOUT( { h_v^(L) : v in V } )

where READOUT is, e.g., mean, sum, max, or attention pooling.

Then a final MLP maps h_G to the target. Sum pooling preserves graph size information; mean pooling does not.
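A minimal illustration with two hypothetical graphs whose nodes all carry the same embedding but which differ in size (3 nodes vs 6): only sum pooling tells them apart.

```python
import numpy as np

node = np.array([1.0, -0.5])
small = np.tile(node, (3, 1))     # 3 identical node embeddings
large = np.tile(node, (6, 1))     # 6 identical node embeddings

def readout(H, how):
    """Pool (num_nodes, d) node embeddings into one (d,) graph embedding."""
    return {"sum": H.sum(axis=0), "mean": H.mean(axis=0), "max": H.max(axis=0)}[how]

for how in ("sum", "mean", "max"):
    print(how, readout(small, how), readout(large, how))
# sum gives [3, -1.5] vs [6, -3]; mean and max both give [1, -0.5] for either graph.
```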

3. PyTorch reference implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


def normalise_adjacency(A: torch.Tensor) -> torch.Tensor:
    """Compute D_hat^(-1/2) (A + I) D_hat^(-1/2) for a dense adjacency."""
    N = A.size(0)
    A_hat = A + torch.eye(N, device=A.device)
    deg   = A_hat.sum(dim=1)
    d_inv_sqrt = torch.diag(deg.pow(-0.5))
    return d_inv_sqrt @ A_hat @ d_inv_sqrt


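class GCNLayer(nn.Module):
    """One GCN layer: H' = A_tilde H W + b (sigma applied by the caller). A_tilde is precomputed once."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        # Bias is added after propagation: rows of A_tilde do not sum to 1,
        # so a bias inside the Linear would be rescaled differently per node.
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, H, A_tilde):
        return A_tilde @ self.W(H) + self.bias  # ReLU applied outside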


class GCN(nn.Module):
    """Two-layer GCN suitable for Cora-style node classification."""
    def __init__(self, in_dim, hidden, num_classes, dropout=0.5):
        super().__init__()
        self.g1 = GCNLayer(in_dim, hidden)
        self.g2 = GCNLayer(hidden, num_classes)
        self.dropout = dropout

    def forward(self, X, A_tilde):
        h = F.relu(self.g1(X, A_tilde))
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.g2(h, A_tilde)              # logits


def train_node_classifier(model, X, A_tilde, y, train_mask, n_steps=200, lr=1e-2):
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    for step in range(n_steps):
        model.train(True)
        logits = model(X, A_tilde)
        loss   = F.cross_entropy(logits[train_mask], y[train_mask])
        optim.zero_grad()
        loss.backward()
        optim.step()
        if step % 50 == 0:
            print(step, float(loss))


@torch.no_grad()
def accuracy(model, X, A_tilde, y, mask):
    model.train(False)                          # inference mode: disables dropout
    pred = model(X, A_tilde).argmax(dim=-1)
    return float((pred[mask] == y[mask]).float().mean())


if __name__ == "__main__":
    torch.manual_seed(0)
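    # Tiny Cora-like setup: 50 nodes, 7 classes, random features and random labels,
    # so validation accuracy will be near chance (~1/7); this only checks that the pipeline runs.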
    N, F_in, C = 50, 16, 7
    X = torch.randn(N, F_in)
    y = torch.randint(0, C, (N,))
    # Build a random sparse adjacency (symmetric, no self loops).
    A = (torch.rand(N, N) < 0.08).float()
    A = ((A + A.T) > 0).float()
    A.fill_diagonal_(0)
    A_tilde = normalise_adjacency(A)
    # Train / val masks.
    perm = torch.randperm(N)
    train_mask = torch.zeros(N, dtype=torch.bool); train_mask[perm[:30]] = True
    val_mask   = torch.zeros(N, dtype=torch.bool); val_mask[perm[30:]]  = True

    model = GCN(F_in, hidden=32, num_classes=C)
    train_node_classifier(model, X, A_tilde, y, train_mask, n_steps=100)
    print("val acc:", accuracy(model, X, A_tilde, y, val_mask))

model.train(False) switches the module to inference mode (the same effect as the module's short-form inference call), which turns dropout off. Real GNN code uses a sparse adjacency (PyTorch Geometric or DGL); the dense N x N matrix is shown here for clarity, but its O(N^2) memory makes it practical only up to a few thousand nodes.
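Sparse propagation can be sketched without any GNN library: store the graph as an edge list and scatter-add per-edge messages, which needs O(N + E) memory instead of O(N^2). The NumPy function below is a hypothetical helper, not the PyTorch Geometric or DGL API (those libraries provide their own optimised layers); it assumes each undirected edge appears in both directions and that the input has no self-loops or duplicate edges. The usage lines check it against the dense formula.

```python
import numpy as np

def gcn_propagate_sparse(X, edge_index, num_nodes):
    """Compute A_tilde @ X from an edge list without building the N x N matrix.

    edge_index: (2, E) int array of directed edges (src, dst). Self-loops are added here.
    """
    loops = np.arange(num_nodes)
    src = np.concatenate([edge_index[0], loops])
    dst = np.concatenate([edge_index[1], loops])
    deg = np.bincount(dst, minlength=num_nodes).astype(float)  # degree incl. self-loop
    norm = 1.0 / np.sqrt(deg[src] * deg[dst])                   # per-edge weight 1/sqrt(d_u d_v)
    out = np.zeros_like(X, dtype=float)
    np.add.at(out, dst, norm[:, None] * X[src])                 # scatter-add messages into receivers
    return out

und = [(0, 1), (1, 2), (2, 3), (1, 3)]
edge_index = np.array([[u for u, v in und] + [v for u, v in und],
                       [v for u, v in und] + [u for u, v in und]])
X = np.random.default_rng(0).normal(size=(4, 3))

# Dense reference: A_tilde = D_hat^(-1/2) (A + I) D_hat^(-1/2).
A = np.zeros((4, 4))
A[edge_index[0], edge_index[1]] = 1.0
A_hat = A + np.eye(4)
d = A_hat.sum(axis=1)
dense = (A_hat / np.sqrt(np.outer(d, d))) @ X
print(np.allclose(dense, gcn_propagate_sparse(X, edge_index, 4)))
```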

4. Common USAAIO / IOAI applications

5. Drills

D1 · Why normalise the adjacency?

What goes wrong if you use the raw A instead of D_hat^(-1/2) A_hat D_hat^(-1/2)?

Solution

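(1) Without self-loops (A_hat = A + I), a node never sees itself: its updated representation depends only on its neighbours, so its own features are dropped after one layer. (2) Without degree normalisation, a node's aggregated message scales with its degree, so high-degree nodes receive huge sums whose activations explode across layers, while low-degree nodes lag behind. Symmetric normalisation rebalances the scale across nodes and keeps every eigenvalue of A_tilde within (-1, 1].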
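A small NumPy demonstration of both failure modes on a hypothetical star graph (one hub with 8 leaves). With the raw A, the hub's own signal is absent from its update and its aggregated magnitude equals its degree; with A_tilde, the hub keeps part of its own signal and the hub-to-leaf ratio on an all-ones input drops from 8 to roughly 2.7.

```python
import numpy as np

# Hypothetical star graph: node 0 is a hub with 8 leaf neighbours (nodes 1..8).
N = 9
A = np.zeros((N, N))
A[0, 1:] = 1.0
A[1:, 0] = 1.0
x = np.zeros(N)
x[0] = 1.0                          # only the hub carries a signal
ones = np.ones(N)

# Raw A: the hub never sees itself, and its aggregate grows with its degree.
print((A @ x)[0])                   # 0.0
print((A @ ones)[0], (A @ ones)[1]) # 8.0 for the hub vs 1.0 for a leaf

# Normalised A_tilde = D_hat^(-1/2) (A + I) D_hat^(-1/2).
A_hat = A + np.eye(N)
d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
A_tilde = d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]
print((A_tilde @ x)[0])             # 1/9: the hub keeps part of its own signal
print((A_tilde @ ones)[0] / (A_tilde @ ones)[1])   # hub/leaf ratio, about 2.7
```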

D2 · Receptive field after L layers

A GCN with 3 layers — how far does information flow?

Solution

3 hops. Each layer propagates messages from immediate neighbours, so L layers stack to an L-hop receptive field. This is also why deep GNNs over-smooth: every node ends up aggregating over nearly the whole graph, and repeated averaging drives all representations toward nearly the same value.
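The L-hop claim can be checked directly: the nodes that can influence node v after L layers are the nonzero entries of row v of (A + I)^L. The sketch below computes that reachability pattern on a hypothetical 7-node path graph; receptive_field is an illustrative helper name.

```python
import numpy as np

def receptive_field(A, num_layers):
    """Boolean (N, N) matrix: entry [v, u] is True if node u can influence
    node v after num_layers message-passing layers (self-loops included)."""
    N = A.shape[0]
    R = np.eye(N, dtype=bool)
    step = (A + np.eye(N)) > 0
    for _ in range(num_layers):
        R = (R.astype(int) @ step.astype(int)) > 0   # extend reach by one hop
    return R

# Path graph 0 - 1 - 2 - 3 - 4 - 5 - 6
N = 7
A = np.zeros((N, N))
for i in range(N - 1):
    A[i, i + 1] = A[i + 1, i] = 1.0
print(np.flatnonzero(receptive_field(A, 3)[0]))   # [0 1 2 3]: exactly 3 hops from node 0
```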

D3 · GCN vs GAT on a hub node

A user-account node has 10 000 neighbours, most of which are bots. GCN treats all neighbours equally (after degree normalisation). What does GAT do differently?

Solution

GAT learns per-edge attention coefficients alpha_{vu}; it can down-weight the bot neighbours and concentrate on the few informative edges. The cost is more parameters and slightly more compute per layer.

D4 · Sum vs mean pooling on molecules

You're predicting solubility from a molecule graph. Sum or mean readout?

Solution

Sum preserves molecule size: a bigger molecule produces a larger-magnitude embedding, and solubility correlates with size. Mean discards that signal. For size-invariant targets (some shape descriptors) mean is fine; for extensive properties (mass, energy) and size-dependent ones such as solubility, sum is usually better. Attention pooling can learn either.

D5 · Debugging over-smoothing

Your 6-layer GCN trains to low loss, but its validation accuracy is worse than a 2-layer GCN's. What's happening, and what fixes it?

Solution

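Over-smoothing: node representations converge toward a graph-wide constant. Fixes: reduce depth to 2-3 layers; add residual connections (h^(l+1) = h^(l) + layer(h^(l))); use PairNorm or DropEdge; or switch to an architecture such as GraphSAGE (which keeps a separate self term) or GAT (which can down-weight neighbours), both of which can partially mitigate the problem.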
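DropEdge, one of the fixes listed above, is simple enough to sketch: at each training epoch, randomly delete a fraction p of the edges before normalising the adjacency, so the model averages over a sparser, randomly varying graph. The function below is a minimal NumPy version of the idea under stated assumptions (symmetric 0/1 adjacency, no self-loops), not the original authors' code; at evaluation time you would use the full graph.

```python
import numpy as np

def drop_edge(A, p, rng):
    """DropEdge-style augmentation (sketch): remove each undirected edge
    independently with probability p, keeping the adjacency symmetric.
    Assumes A is a symmetric 0/1 matrix without self-loops; self-loops are
    added later by the D_hat^(-1/2) (A + I) D_hat^(-1/2) normalisation."""
    upper = np.triu(A, k=1)                     # each undirected edge counted once
    keep = rng.random(upper.shape) >= p         # keep with probability 1 - p
    upper = upper * keep
    return upper + upper.T

rng = np.random.default_rng(0)
N = 8
A = np.triu((rng.random((N, N)) < 0.5).astype(float), k=1)
A = A + A.T                                     # random symmetric graph, no self-loops
A_drop = drop_edge(A, p=0.3, rng=rng)           # re-sample every epoch, then renormalise
print(int(A.sum() // 2), "edges before,", int(A_drop.sum() // 2), "after")
```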

Next step

GAT generalises the attention mechanism from Transformers to graph structure; revisit that page for multi-head attention details. For short-answer drilling on message-passing math and over-smoothing, head to Round 2 theory, then practice end-to-end on a molecule task in Problems.