Graph Neural Networks (GNN)
Neural networks that operate on graph-structured data. Nodes and edges carry features; each layer updates a node's representation by aggregating messages from its neighbours. Used for molecules, social networks, knowledge graphs, and any data with non-Euclidean structure.
1. The intuition
A CNN exploits a grid: every pixel has a fixed neighbourhood (the 3x3 window) and weights are shared across positions. A GNN generalises this to arbitrary graphs: each node has a variable-size neighbourhood, and the same learned function is applied at every node.
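A minimal plain-Python sketch of the variable-size neighbourhood idea. It assumes a hypothetical 5-node undirected graph stored as an edge list. A hand-picked scalar weight `w` stands in for the learned function; neither comes from the text. Every node gets its own neighbour list, yet the same rule (w times the mean of the neighbour features) is applied everywhere:

```python
from collections import defaultdict

# Hypothetical toy graph: 5 nodes, undirected edges.
edges = [(0, 1), (0, 2), (0, 3), (1, 2), (3, 4)]

def neighbour_lists(edges, num_nodes):
    """Return N(v) for every node v; the sizes differ from node to node."""
    nbrs = defaultdict(list)
    for u, v in edges:
        nbrs[u].append(v)
        nbrs[v].append(u)
    return [sorted(nbrs[v]) for v in range(num_nodes)]

nbrs = neighbour_lists(edges, 5)
print(nbrs)  # [[1, 2, 3], [0, 2], [0, 1], [0, 4], [3]]

# One shared rule for every node: w * mean of the neighbour features.
x = [1.0, 2.0, 3.0, 4.0, 5.0]   # one scalar feature per node
w = 0.5                         # stands in for a learned weight
h = [w * sum(x[u] for u in nv) / len(nv) for nv in nbrs]
print(h)  # [1.5, 1.0, 0.75, 1.5, 2.0]
```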
One layer of a GNN propagates information one hop. After L layers, each
node's representation has seen everything within L hops. This is exactly
analogous to receptive field growth in CNNs. The aggregation must be
permutation-invariant over neighbours (a node has no inherent ordering of its
neighbours), so we use sum, mean, max, or attention — never concatenation in a fixed
order.
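A quick check of the permutation-invariance requirement, sketched in PyTorch with random messages and an arbitrary reordering (both illustrative assumptions). Sum, mean and max return the same result whatever order the neighbours arrive in. Concatenating the messages in the order given does not:

```python
import torch

torch.manual_seed(0)
msgs = torch.randn(4, 3)           # 4 neighbour messages with 3 features each
perm = torch.tensor([2, 0, 3, 1])  # an arbitrary reordering of the same neighbours
shuffled = msgs[perm]

aggs = {
    "sum": lambda m: m.sum(dim=0),
    "mean": lambda m: m.mean(dim=0),
    "max": lambda m: m.amax(dim=0),
}
# Each aggregator gives the same result whatever the neighbour order ...
invariant = {name: torch.allclose(agg(msgs), agg(shuffled)) for name, agg in aggs.items()}
# ... whereas concatenating the messages in the given order does not.
concat_same = torch.equal(msgs.flatten(), shuffled.flatten())
print(invariant, concat_same)
```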
Two failure modes shape design: over-smoothing (with too many layers, all node representations converge towards the same value) and over-squashing (information from many distant nodes must be compressed into fixed-size vectors through narrow bottlenecks in the graph). Practical GNNs are typically 2-5 layers deep, often with residual connections.
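Over-smoothing can be seen with the propagation step alone. This sketch strips a GCN down to repeated multiplication by the normalised adjacency, with no weights and no ReLU, which is a deliberate simplification. It runs on a hypothetical 6-node cycle and tracks the average cosine similarity between node rows. On this regular graph the rows converge to the same vector. On irregular graphs, symmetric normalisation makes them converge to scaled copies of one vector instead, which is why cosine similarity is the measure used here:

```python
import torch
import torch.nn.functional as F

def normalise_adjacency(A):
    """D_hat^(-1/2) (A + I) D_hat^(-1/2), as in the GCN section."""
    A_hat = A + torch.eye(A.size(0))
    d_inv_sqrt = torch.diag(A_hat.sum(dim=1).pow(-0.5))
    return d_inv_sqrt @ A_hat @ d_inv_sqrt

def mean_pairwise_cosine(H):
    """Average cosine similarity between distinct node rows (1.0 = all parallel)."""
    Hn = F.normalize(H, dim=1)
    S = Hn @ Hn.T
    n = H.size(0)
    return float((S.sum() - S.diagonal().sum()) / (n * (n - 1)))

# Hypothetical toy graph: a 6-node cycle.
n = 6
idx = torch.arange(n)
A = torch.zeros(n, n)
A[idx, (idx + 1) % n] = 1.0
A[(idx + 1) % n, idx] = 1.0
A_tilde = normalise_adjacency(A)

torch.manual_seed(0)
H = torch.randn(n, 4)
history = {}
for layer in range(1, 33):
    H = A_tilde @ H  # propagation only: no W, no ReLU
    if layer in (1, 2, 4, 8, 16, 32):
        history[layer] = mean_pairwise_cosine(H)
        print(layer, round(history[layer], 4))
```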
2. The math
Message passing — the general form
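Let h_v^(l) be node v's feature vector at layer l, and N(v) its neighbour set. A message-passing layer is:
$$m_v^{(l)} = \mathrm{AGG}_{u \in N(v)}\, \mathrm{MSG}\big(h_v^{(l)}, h_u^{(l)}\big), \qquad h_v^{(l+1)} = \mathrm{UPDATE}\big(h_v^{(l)}, m_v^{(l)}\big)$$
where AGG is permutation-invariant (sum / mean / max / attention), MSG is typically a small MLP or a linear map, and UPDATE usually combines the old state with the aggregated message, often through a residual connection. Every common GNN (GCN, GAT, GraphSAGE, GIN) is an instance of this template.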
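A minimal PyTorch sketch of the template. It assumes sum aggregation, an MSG that sees the receiver and sender features, and an UPDATE with a residual connection. Edges are stored as a (2, E) index tensor, similar in spirit to PyTorch Geometric's `edge_index` convention, but no library is used. The class name is hypothetical:

```python
import torch
import torch.nn as nn

class MessagePassingLayer(nn.Module):
    """Generic layer: m_v = sum_{u in N(v)} MSG(h_v, h_u); h_v' = h_v + ReLU(UPDATE(h_v, m_v))."""

    def __init__(self, dim):
        super().__init__()
        # MSG: small MLP on the concatenated (receiver, sender) features.
        self.msg = nn.Sequential(nn.Linear(2 * dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        # UPDATE: linear map on [old state, aggregated message].
        self.update = nn.Linear(2 * dim, dim)

    def forward(self, h, edge_index):
        # edge_index has shape (2, E): row 0 = sender u, row 1 = receiver v.
        src, dst = edge_index
        m = self.msg(torch.cat([h[dst], h[src]], dim=-1))   # one message per edge
        # AGG = sum: scatter-add each message into its receiver (order-independent).
        agg = torch.zeros_like(h).index_add_(0, dst, m)
        return h + torch.relu(self.update(torch.cat([h, agg], dim=-1)))  # residual

torch.manual_seed(0)
h = torch.randn(4, 8)
# Undirected path 0-1-2-3, stored with both directions of every edge.
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
layer = MessagePassingLayer(8)
out = layer(h, edge_index)
print(out.shape)  # torch.Size([4, 8])
```

Because AGG is a sum, relabelling the nodes only permutes the output rows. The layer is permutation-equivariant, which is the node-level counterpart of the invariance requirement above.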
GCN — Kipf & Welling
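Given the adjacency matrix A, add self-loops, A_hat = A + I, and let D_hat be the diagonal degree matrix of A_hat (D_hat_vv = sum_u A_hat_vu). The symmetrically normalised adjacency is:
$$\tilde{A} = \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2}$$
A GCN layer in matrix form is then:
$$H^{(l+1)} = \sigma\big(\tilde{A} H^{(l)} W^{(l)}\big)$$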
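where H^(l) in R^(N x d_l) stacks all node features, W^(l) in R^(d_l x d_(l+1)) is a learned weight matrix, and sigma is a nonlinearity, usually ReLU (typically omitted on the output layer, which produces logits). The normalisation D_hat^(-1/2) A_hat D_hat^(-1/2) is critical. It weights each edge (u, v) by 1/sqrt(d_hat_u d_hat_v), so high-degree nodes do not dominate. Its eigenvalues also lie in (-1, 1], so the propagation step itself does not amplify activations across layers.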
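A worked example of the normalisation on a hypothetical 3-node path 0 - 1 - 2. With self-loops the degrees are 2, 3, 2. The diagonal of A_tilde is therefore 1/2, 1/3, 1/2, and each edge entry is 1/sqrt(2 * 3), about 0.408. The sketch also checks that the matrix form of one layer (before the nonlinearity) matches the per-node sum with weights 1/sqrt(d_v d_u):

```python
import torch

# Hypothetical toy graph: the path 0 - 1 - 2.
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
A_hat = A + torch.eye(3)
deg = A_hat.sum(dim=1)                   # degrees with self-loops: 2, 3, 2
D_inv_sqrt = torch.diag(deg.pow(-0.5))
A_tilde = D_inv_sqrt @ A_hat @ D_inv_sqrt
print(A_tilde)  # diagonal 1/2, 1/3, 1/2; each edge entry 1/sqrt(6) ~ 0.408

# Same layer written per node: row v = sum over u in N(v) + {v} of (H[u] @ W) / sqrt(d_v d_u).
torch.manual_seed(0)
H = torch.randn(3, 4)
W = torch.randn(4, 2)
dense = A_tilde @ H @ W
per_node = torch.stack([
    sum(A_hat[v, u] * (H[u] @ W) / torch.sqrt(deg[v] * deg[u]) for u in range(3))
    for v in range(3)
])
print(torch.allclose(dense, per_node, atol=1e-5))  # True
```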
GAT — graph attention
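Replace the fixed A_tilde weights with learned attention coefficients. For each edge (u, v), with a learned vector a and || denoting concatenation:
$$e_{vu} = \mathrm{LeakyReLU}\big(a^\top [\,W h_v \,\|\, W h_u\,]\big), \qquad \alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{k \in N(v)} \exp(e_{vk})}, \qquad h_v' = \sigma\Big(\sum_{u \in N(v)} \alpha_{vu} W h_u\Big)$$
Here N(v) includes v itself (a self-loop), mirroring A_hat in GCN.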
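Multi-head GAT runs K independent attention heads and concatenates their outputs (or averages them on the final layer). The structure echoes transformer self-attention restricted to graph neighbours: it is attention with a mask set by the adjacency. The difference is that GAT scores each pair with a learned additive function followed by LeakyReLU rather than a scaled dot product.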
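A minimal single-head GAT layer on a dense adjacency. It is a sketch in the style of the reference code below, not a library implementation. It follows the original GAT scoring rule, e_vu = LeakyReLU(a^T [W h_v || W h_u]) with a softmax over N(v) including v. Multi-head logic and the output nonlinearity are omitted, and the small random initialisation of a is an assumption. Non-edges are masked to -inf before the softmax, which is exactly the "attention with an adjacency mask" view:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    """Single-head GAT layer on a dense adjacency; the nonlinearity is left to the caller."""

    def __init__(self, in_dim, out_dim, negative_slope=0.2):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        # The attention vector a = [a_dst || a_src], stored as two halves.
        self.a_dst = nn.Parameter(0.1 * torch.randn(out_dim))
        self.a_src = nn.Parameter(0.1 * torch.randn(out_dim))
        self.negative_slope = negative_slope

    def forward(self, H, A):
        Z = self.W(H)                                             # rows are W h_v
        # a^T [W h_v || W h_u] = a_dst . W h_v + a_src . W h_u, for every pair (v, u)
        e = (Z @ self.a_dst).unsqueeze(1) + (Z @ self.a_src).unsqueeze(0)
        e = F.leaky_relu(e, self.negative_slope)
        # Mask non-edges; the self-loop keeps every row non-empty.
        mask = (A + torch.eye(A.size(0), device=A.device)) > 0
        e = e.masked_fill(~mask, float("-inf"))
        alpha = torch.softmax(e, dim=1)                           # softmax over u in N(v)
        return alpha @ Z, alpha                                   # h_v' = sum_u alpha_vu W h_u

torch.manual_seed(0)
A = torch.tensor([[0., 1., 1., 0.],
                  [1., 0., 0., 0.],
                  [1., 0., 0., 1.],
                  [0., 0., 1., 0.]])
layer = GATLayer(5, 3)
out, alpha = layer(torch.randn(4, 5), A)
print(alpha.sum(dim=1))  # each row sums to 1
```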
Readout for graph-level tasks
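For node-level tasks (e.g. citation classification) the per-node output is used directly. For graph-level tasks (e.g. molecule property prediction) all final node features are pooled into a single graph embedding:
$$h_G = \sum_{v \in V} h_v^{(L)} \quad \text{(sum pooling)} \qquad \text{or} \qquad h_G = \frac{1}{|V|} \sum_{v \in V} h_v^{(L)} \quad \text{(mean pooling)}$$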
Then a final MLP maps h_G to the target. Sum pooling preserves graph size
information; mean pooling does not.
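A small check of the size argument, plus a batched sum readout. The "bigger graph" is a hypothetical 3-node graph's features duplicated. The `batch` vector (graph id per node) is an assumption modelled on how PyTorch Geometric batches graphs, implemented here with plain `index_add_`:

```python
import torch

def readout(H, mode="sum"):
    """Pool node features H of shape (num_nodes, d) into one graph embedding h_G."""
    return H.sum(dim=0) if mode == "sum" else H.mean(dim=0)

def batched_sum_readout(H, batch, num_graphs):
    """Sum-pool several graphs at once; batch[i] is the graph id of node i."""
    return torch.zeros(num_graphs, H.size(1)).index_add_(0, batch, H)

torch.manual_seed(0)
small = torch.randn(3, 4)               # final node features of a 3-node graph
big = torch.cat([small, small], dim=0)  # same features, twice as many nodes

print(torch.allclose(readout(big, "sum"), 2 * readout(small, "sum")))   # True: sum sees size
print(torch.allclose(readout(big, "mean"), readout(small, "mean")))     # True: mean does not

H = torch.cat([small, big], dim=0)      # two graphs packed into one node matrix
batch = torch.tensor([0] * 3 + [1] * 6)
h_G = batched_sum_readout(H, batch, num_graphs=2)   # shape (2, 4)
```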
3. PyTorch reference implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def normalise_adjacency(A: torch.Tensor) -> torch.Tensor:
    """Compute D_hat^(-1/2) (A + I) D_hat^(-1/2) for a dense adjacency."""
    N = A.size(0)
    A_hat = A + torch.eye(N, device=A.device, dtype=A.dtype)
    deg = A_hat.sum(dim=1)                  # >= 1 thanks to the self-loops
    d_inv_sqrt = torch.diag(deg.pow(-0.5))
    return d_inv_sqrt @ A_hat @ d_inv_sqrt


class GCNLayer(nn.Module):
    """One GCN layer before the nonlinearity: A_tilde H W + b (A_tilde precomputed once)."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        # Bias added after aggregation, so it is not rescaled by the rows of A_tilde.
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, H, A_tilde):
        return A_tilde @ self.W(H) + self.bias  # ReLU applied by the caller


class GCN(nn.Module):
    """Two-layer GCN suitable for Cora-style node classification."""

    def __init__(self, in_dim, hidden, num_classes, dropout=0.5):
        super().__init__()
        self.g1 = GCNLayer(in_dim, hidden)
        self.g2 = GCNLayer(hidden, num_classes)
        self.dropout = dropout

    def forward(self, X, A_tilde):
        h = F.relu(self.g1(X, A_tilde))
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.g2(h, A_tilde)  # logits (cross_entropy applies the softmax)


def train_node_classifier(model, X, A_tilde, y, train_mask, n_steps=200, lr=1e-2):
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
    for step in range(n_steps):
        model.train(True)  # enable dropout
        logits = model(X, A_tilde)  # the forward pass uses the whole graph
        # The loss only covers labelled (training) nodes.
        loss = F.cross_entropy(logits[train_mask], y[train_mask])
        optim.zero_grad()
        loss.backward()
        optim.step()
        if step % 50 == 0:
            print(step, float(loss))


@torch.no_grad()
def accuracy(model, X, A_tilde, y, mask):
    model.train(False)  # inference mode (same as model.eval()): disables dropout
    pred = model(X, A_tilde).argmax(dim=-1)
    return float((pred[mask] == y[mask]).float().mean())


if __name__ == "__main__":
    torch.manual_seed(0)
    # Tiny Cora-like setup: 50 nodes, 7 classes, random features and labels,
    # so expect roughly chance-level validation accuracy (a smoke test only).
    N, F_in, C = 50, 16, 7
    X = torch.randn(N, F_in)
    y = torch.randint(0, C, (N,))
    # Random sparse adjacency: symmetric, no self-loops (normalise_adjacency adds them).
    A = (torch.rand(N, N) < 0.08).float()
    A = ((A + A.T) > 0).float()
    A.fill_diagonal_(0)
    A_tilde = normalise_adjacency(A)
    # Train / validation masks over a random node split.
    perm = torch.randperm(N)
    train_mask = torch.zeros(N, dtype=torch.bool)
    train_mask[perm[:30]] = True
    val_mask = torch.zeros(N, dtype=torch.bool)
    val_mask[perm[30:]] = True
    model = GCN(F_in, hidden=32, num_classes=C)
    train_node_classifier(model, X, A_tilde, y, train_mask, n_steps=100)
    print("val acc:", accuracy(model, X, A_tilde, y, val_mask))
```
model.train(False) is the standard inference switch: it has the same effect as model.eval() and turns dropout off. Real GNN code stores the adjacency sparsely (PyTorch Geometric or DGL). The dense N x N matrix is shown here for clarity, but its O(N^2) memory limits it to graphs of a few thousand nodes.
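A minimal sketch of the same normalisation with a sparse adjacency. It uses PyTorch's built-in COO tensors and `torch.sparse.mm` rather than PyTorch Geometric or DGL. It assumes `edge_index` lists every undirected edge in both directions, with no duplicates and no existing self-loops. The function name is hypothetical. Memory now scales with the number of edges rather than N^2:

```python
import torch

def normalise_adjacency_sparse(edge_index, num_nodes):
    """Sparse D_hat^(-1/2) (A + I) D_hat^(-1/2) from a (2, E) edge list."""
    loops = torch.arange(num_nodes)
    idx = torch.cat([edge_index, torch.stack([loops, loops])], dim=1)  # add self-loops
    row, col = idx
    # Degree of A_hat = number of stored entries in each row.
    deg = torch.zeros(num_nodes).index_add_(0, row, torch.ones(row.numel()))
    vals = deg[row].pow(-0.5) * deg[col].pow(-0.5)   # 1 / sqrt(d_hat_v d_hat_u) per entry
    return torch.sparse_coo_tensor(idx, vals, (num_nodes, num_nodes)).coalesce()

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])          # path 0-1-2, both directions
A_sp = normalise_adjacency_sparse(edge_index, 3)
H = torch.randn(3, 4)
out = torch.sparse.mm(A_sp, H)                     # sparse @ dense -> dense (3, 4)
```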
4. Common USAAIO / IOAI applications
- Molecule property prediction — nodes = atoms, edges = bonds, features = atomic number / charge / hybridisation. Targets: solubility, toxicity, binding affinity. OGB datasets (ogbg-molhiv, ogbg-molpcba) are the standard.
- Citation network classification — Cora, CiteSeer, PubMed: predict paper category from text features + citation graph.
- Social network / recommendation — friend prediction, item recommendation as link prediction over a user-item bipartite graph.
- Combinatorial optimisation — learn heuristics for TSP / graph colouring with GNN policies.
- Knowledge graph completion — predict missing triples; relational GNN variants (R-GCN) handle multiple edge types.
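For the link-prediction framing used in recommendation, here is a minimal sketch under stated assumptions. Node embeddings `Z` are assumed to come from some GNN. A dot-product decoder scores candidate edges, and uniformly random node pairs serve as negatives. These are common choices but assumptions here, and the helper names are hypothetical:

```python
import torch
import torch.nn.functional as F

def link_logits(Z, pairs):
    """Score candidate edges (u, v) by the dot product z_u . z_v."""
    return (Z[pairs[0]] * Z[pairs[1]]).sum(dim=-1)

def link_loss(Z, pos_pairs, num_nodes, generator=None):
    """Binary cross-entropy: observed edges vs. uniformly sampled random pairs."""
    # Random negatives may occasionally hit a true edge; acceptable for a sketch.
    neg_pairs = torch.randint(0, num_nodes, pos_pairs.shape, generator=generator)
    logits = torch.cat([link_logits(Z, pos_pairs), link_logits(Z, neg_pairs)])
    labels = torch.cat([torch.ones(pos_pairs.size(1)), torch.zeros(neg_pairs.size(1))])
    return F.binary_cross_entropy_with_logits(logits, labels)

g = torch.Generator().manual_seed(0)
Z = torch.randn(5, 8, requires_grad=True)   # stand-in for GNN node embeddings
pos = torch.tensor([[0, 1, 2],
                    [1, 2, 3]])             # observed edges u -> v
loss = link_loss(Z, pos, num_nodes=5, generator=g)
loss.backward()
```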
5. Drills
D1 · Why normalise the adjacency?
What goes wrong if you use the raw A instead of
D_hat^(-1/2) A_hat D_hat^(-1/2)?
Solution
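(1) Without self-loops (A_hat = A + I), a node never aggregates its own features, so after one layer its representation is built only from its neighbours and its own information is lost. (2) Without degree normalisation, high-degree nodes accumulate large sums of messages and their activations grow across layers, while low-degree nodes lag behind. Symmetric normalisation rebalances the contributions (each edge is weighted by 1/sqrt(d_hat_u d_hat_v)) and keeps the scale stable.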
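A numerical illustration of point (2) on a hypothetical star graph (hub 0 joined to 9 leaves). Six propagation steps are applied to an all-ones feature with no weights or nonlinearity. With the raw A, the hub immediately receives 9 times the leaves' value and the norm grows like 3^6. It reaches exactly 729 * sqrt(10) here. The normalised matrix has eigenvalues in (-1, 1], so the norm can never exceed its starting value sqrt(10):

```python
import torch

def propagated_norm(M, H, steps):
    """Frobenius norm of M^steps @ H (propagation only, no weights or nonlinearity)."""
    for _ in range(steps):
        H = M @ H
    return H.norm().item()

# Hypothetical star graph: hub 0 joined to 9 leaves.
n = 10
A = torch.zeros(n, n)
A[0, 1:] = 1.0
A[1:, 0] = 1.0
A_hat = A + torch.eye(n)
d_inv_sqrt = torch.diag(A_hat.sum(dim=1).pow(-0.5))
A_tilde = d_inv_sqrt @ A_hat @ d_inv_sqrt

H = torch.ones(n, 1)
raw = propagated_norm(A, H, 6)           # 729 * sqrt(10) ~ 2305.3
normed = propagated_norm(A_tilde, H, 6)  # at most ||H|| = sqrt(10) ~ 3.16
print(raw, normed)
```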
D2 · Receptive field after L layers
A GCN with 3 layers — how far does information flow?
Solution
3 hops. Each layer propagates messages from immediate neighbours, so L layers stack to an L-hop receptive field. This also hints at why deep GNNs over-smooth: with many layers, every node's receptive field covers nearly the whole graph, and repeated neighbourhood averaging pulls all representations toward a shared value.
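The receptive field can be read off powers of A_hat. Entry (v, u) of A_hat^L is positive exactly when u is within L hops of v, because the self-loops let shorter walks be padded to length L. A sketch on a hypothetical 6-node path:

```python
import torch

# Hypothetical toy graph: the path 0-1-2-3-4-5.
n = 6
A = torch.zeros(n, n)
for i in range(n - 1):
    A[i, i + 1] = A[i + 1, i] = 1.0
A_hat = A + torch.eye(n)

# (A_hat^3)[v, u] > 0 exactly when u is within 3 hops of v, so after 3
# propagation steps node v's output can only depend on those nodes.
reach = torch.linalg.matrix_power(A_hat, 3) > 0
print(reach[0].nonzero().flatten().tolist())  # [0, 1, 2, 3]
print(reach[5].nonzero().flatten().tolist())  # [2, 3, 4, 5]
```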
D3 · GCN vs GAT on a hub node
A user-account node has 10 000 neighbours, most of which are bots. GCN treats all neighbours equally (after degree normalisation). What does GAT do differently?
Solution
GAT learns per-edge attention coefficients alpha_{vu}; it can
down-weight the bot neighbours and concentrate on the few informative edges. The
cost is more parameters and slightly more compute per layer.
D4 · Sum vs mean pooling on molecules
You're predicting solubility from a molecule graph. Sum or mean readout?
Solution
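Sum preserves molecule size: a bigger molecule produces a larger-magnitude embedding, and solubility correlates with size. Mean discards that signal. For size-invariant targets (some shape descriptors) mean is fine. For extensive properties (mass, energy) and size-dependent ones such as solubility, sum is usually better. Attention pooling learns per-node weights. Unnormalised (e.g. sigmoid-gated) weights can keep size information, while softmax-normalised weights behave like a weighted mean.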
D5 · Debugging over-smoothing
Your 6-layer GCN trains to low loss but its validation accuracy is worse than a 2-layer GCN's. What's happening, and what fixes it?
Solution
Over-smoothing: node representations converge towards a graph-wide constant, so nodes become hard to tell apart. Fixes:
reduce depth to 2-3 layers; add residual connections (h^(l+1) = sigma(A_tilde h^(l) W^(l)) + h^(l));
use PairNorm or DropEdge; or switch to GAT / GraphSAGE, which can reduce (though not eliminate) the problem.
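A sketch of two of these fixes together: residual connections around each GCN layer, and a simplified DropEdge-style step that removes each undirected edge with probability p during training before renormalising. The exact DropEdge sampling details, the class name and the default sizes are assumptions for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def normalise_adjacency(A):
    """D_hat^(-1/2) (A + I) D_hat^(-1/2) for a dense adjacency."""
    A_hat = A + torch.eye(A.size(0), device=A.device)
    d_inv_sqrt = torch.diag(A_hat.sum(dim=1).pow(-0.5))
    return d_inv_sqrt @ A_hat @ d_inv_sqrt

def drop_edge(A, p, generator=None):
    """Drop each undirected edge independently with probability p (symmetric result)."""
    keep = (torch.rand(A.shape, generator=generator) >= p).float()
    keep = torch.triu(keep, diagonal=1)   # one decision per pair (u < v) ...
    keep = keep + keep.T                  # ... mirrored to (v, u)
    return A * keep

class ResidualGCN(nn.Module):
    """Deep GCN with residual links: h^(l+1) = ReLU(A_tilde h^(l) W^(l)) + h^(l)."""

    def __init__(self, in_dim, hidden, num_classes, num_layers=6, p_edge=0.2):
        super().__init__()
        self.inp = nn.Linear(in_dim, hidden)
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])
        self.out = nn.Linear(hidden, num_classes)
        self.p_edge = p_edge

    def forward(self, X, A):
        if self.training and self.p_edge > 0:
            A = drop_edge(A, self.p_edge)     # fresh random subgraph each training step
        A_tilde = normalise_adjacency(A)
        h = F.relu(self.inp(X))               # project to the hidden width once
        for W in self.layers:
            h = F.relu(A_tilde @ W(h)) + h    # residual connection around each layer
        return self.out(h)                    # logits

torch.manual_seed(0)
N = 20
A = (torch.rand(N, N) < 0.2).float()
A = ((A + A.T) > 0).float()
A.fill_diagonal_(0)
model = ResidualGCN(in_dim=8, hidden=16, num_classes=3)
logits = model(torch.randn(N, 8), A)
print(logits.shape)  # torch.Size([20, 3])
```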
Next step
GAT generalises the attention mechanism from Transformers to graph structure; revisit that page for multi-head attention details. For short-answer drilling on message-passing math and over-smoothing, head to Round 2 theory, then practice end-to-end on a molecule task in Problems.