How GNNs read a crystal structure directly, without handcrafted features, and why they outperform CNNs for property prediction.
- Crystal = graph (atoms = nodes, bonds = edges)
- Message passing between neighbours
- CGCNN, ALIGNN, MEGNet, SchNet
- Formation energy, band gap, elastic moduli
In Post 11 we saw how CNNs extract spatial patterns from 2D images of diffraction patterns. But most materials data is not an image — it is an atomic structure: a set of atoms at precise 3D coordinates, connected by bonds, embedded in a periodic crystal lattice. Squashing this rich relational information into a flat pixel grid or a hand-engineered feature vector throws away most of what matters.
Graph Neural Networks (GNNs) keep the structure intact. They represent a crystal as a graph — atoms become nodes, bonds or neighbourhoods become edges — and then learn to propagate information along those edges. After several rounds of message-passing, each node's representation reflects not only its own element but also its chemical environment, enabling highly accurate property prediction directly from structure.
A multilayer perceptron (MLP) requires a fixed-length input vector, forcing us to choose features in advance (ionic radius, electronegativity, coordination number…). A CNN needs a regular 2D or 3D grid — but crystal structures have variable numbers of atoms per unit cell and complex topology that doesn't map cleanly to a grid. GNNs accept graphs of any size and topology, making them the natural architecture for crystal data.
1. What Is a Graph? — Translating a Crystal Into Nodes and Edges
A graph G = (V, E) is simply a collection of nodes V (vertices) connected by edges E. Each node and edge can carry a feature vector. In the materials context:
- Nodes → atoms. Node features: atomic number Z, electronegativity, ionic radius, oxidation state.
- Edges → bonds or proximity pairs within a cutoff radius (typically 8 Å). Edge features: bond distance, usually Gaussian-expanded. ALIGNN additionally encodes bond angles, which are attached to pairs of bonds through a line graph (Section 4).
- Graph-level label → the property we want to predict (formation energy in eV/atom, band gap in eV, bulk modulus in GPa, …).
Figure: Crystal → graph conversion (BCC iron, 2-atom unit cell)
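The conversion can be sketched in a few lines of NumPy for BCC iron, assuming a cubic cell with a ≈ 2.866 Å and atoms at fractional coordinates (0, 0, 0) and (½, ½, ½). The helper periodic_edges is hypothetical, not a library API: it loops over neighbouring periodic images and keeps every directed pair closer than the cutoff. A 2.6 Å cutoff is used here instead of the usual 8 Å so that only the first shell of 8 nearest neighbours survives.

```python
import itertools

import numpy as np


def periodic_edges(lattice, frac_coords, cutoff):
    """Brute-force neighbour list for a periodic crystal (hypothetical helper).

    lattice: (3, 3) array whose rows are the lattice vectors in Å.
    frac_coords: (N, 3) fractional atomic coordinates.
    Returns directed edges (i, j, image, distance), where atom j sits in the
    periodic image `image` (integer cell offsets) of the home cell.
    """
    lattice = np.asarray(lattice, dtype=float)
    frac = np.asarray(frac_coords, dtype=float)
    # Rows of inv(lattice).T are reciprocal vectors b_i (a_i . b_j = delta_ij);
    # 1/|b_i| is the lattice-plane spacing, so this many cells per axis suffice.
    recip = np.linalg.inv(lattice).T
    reach = np.ceil(cutoff * np.linalg.norm(recip, axis=1)).astype(int) + 1
    edges = []
    for image in itertools.product(*(range(-r, r + 1) for r in reach)):
        shift = np.array(image, dtype=float)
        for i in range(len(frac)):
            for j in range(len(frac)):
                if i == j and not any(image):
                    continue  # an atom is not its own neighbour in the home cell
                disp = (frac[j] + shift - frac[i]) @ lattice  # Cartesian vector i -> j
                d = float(np.linalg.norm(disp))
                if d <= cutoff:
                    edges.append((i, j, image, d))
    return edges


a = 2.866  # approximate BCC Fe lattice constant in Å (assumed value)
edges = periodic_edges(a * np.eye(3), [[0, 0, 0], [0.5, 0.5, 0.5]], cutoff=2.6)
print(len(edges))                                 # 16 (8 neighbours per atom)
print(sorted({round(d, 3) for *_, d in edges}))   # [2.482]
```

Every atom ends up with 8 edges at a√3/2 ≈ 2.48 Å, the bond length used in the next paragraph. Libraries such as JARVIS-Tools provide optimised neighbour searches; this brute-force loop only shows the logic.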
The Gaussian basis function (GBF) expansion of distances is an important preprocessing step: instead of storing a raw scalar d = 2.48 Å, we expand it into a vector of Gaussian "bumps" centred at evenly spaced distances between 0 and 8 Å (20 centres in this illustration; the ALIGNN configuration in Section 7 uses 80). This gives the network a smooth, differentiable representation of distance that is easy to learn from.
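A minimal NumPy sketch of the expansion, assuming 20 centres on [0, 8] Å and a Gaussian width equal to the centre spacing (the same default convention as CGCNN's Gaussian distance filter). The function name gaussian_expand is hypothetical.

```python
import numpy as np


def gaussian_expand(d, d_min=0.0, d_max=8.0, n_centers=20, width=None):
    """Expand distance(s) d in Å into n_centers Gaussian features.

    Centres are evenly spaced on [d_min, d_max]; by default each Gaussian's
    width equals the spacing between neighbouring centres.
    """
    centers = np.linspace(d_min, d_max, n_centers)
    if width is None:
        width = centers[1] - centers[0]
    d = np.asarray(d, dtype=float)[..., None]  # add an axis to broadcast over centres
    return np.exp(-((d - centers) ** 2) / width**2)


features = gaussian_expand(2.48)
print(features.shape)          # (20,)
print(int(features.argmax()))  # 6 (centre at about 2.53 Å)
```

Nearby distances produce overlapping, smoothly varying feature vectors, so a small change in bond length gives a small change in the network input.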
2. Message Passing — How Information Travels Through the Graph
The central idea of GNNs is message passing: each node collects information (messages) from its neighbours, aggregates them, and updates its own hidden representation. After K rounds of message passing, a node's representation encodes information from all nodes within K hops — its K-hop neighbourhood.
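- Message: for each edge linking node i to a neighbour j, compute a message $m_{ij}$ from the features of nodes i and j and the edge features $e_{ij}$.
- Aggregate: sum (or average) all incoming messages to node i: $M_i = \sum_{j \in \mathcal{N}(i)} m_{ij}$.
- Update: pass the aggregated message, together with the current state $h_i$, through a neural network to update node i's hidden state.
- Readout: after K rounds, pool all node vectors into a graph-level vector and predict the property.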
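$$m_{ij}^{(k)} = \phi_m\big(h_i^{(k)},\ h_j^{(k)},\ e_{ij}\big) \qquad \text{(message)}$$
$$M_i^{(k)} = \sum_{j \in \mathcal{N}(i)} m_{ij}^{(k)} \qquad \text{(aggregation, sum)}$$
$$h_i^{(k+1)} = \phi_u\big(h_i^{(k)},\ M_i^{(k)}\big) \qquad \text{(update)}$$
$$\hat{y} = \phi_r\Big(\sum_{i} h_i^{(K)}\Big) \qquad \text{(readout / pooling)}$$
where $\phi_m$, $\phi_u$ and $\phi_r$ are small MLPs with learnable weights.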
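The four steps can be sketched in NumPy under loudly stated assumptions: $\phi_m$ and $\phi_u$ are single tanh layers with random (untrained) weights shared across rounds, $\phi_r$ is a plain linear map, and the graph is a 3-atom toy triangle. All names are illustrative, not a library API. Because both aggregation and readout are sums, relabelling the atoms leaves ŷ unchanged, which the last lines check.

```python
import numpy as np

rng = np.random.default_rng(0)


def mlp(in_dim, out_dim):
    """Stand-in for a learned phi: one linear layer + tanh with random weights."""
    w = rng.normal(scale=0.5, size=(in_dim, out_dim))
    return lambda z: np.tanh(z @ w)


def message_passing_round(h, edge_index, e, phi_m, phi_u):
    """One round of message -> aggregate -> update on a directed edge list."""
    src, dst = edge_index                                   # message flows j=src -> i=dst
    m = phi_m(np.concatenate([h[dst], h[src], e], axis=1))  # m_ij = phi_m(h_i, h_j, e_ij)
    M = np.zeros((h.shape[0], m.shape[1]))
    np.add.at(M, dst, m)                                    # M_i = sum_j m_ij
    return phi_u(np.concatenate([h, M], axis=1))            # h_i <- phi_u(h_i, M_i)


# Toy 3-atom graph: a triangle with both edge directions
h0 = rng.normal(size=(3, 4))                 # 4 node features per atom
edge_index = np.array([[0, 1, 1, 2, 2, 0],   # src (j)
                       [1, 0, 2, 1, 0, 2]])  # dst (i)
e = rng.normal(size=(6, 2))                  # 2 edge features per directed edge
phi_m = mlp(4 + 4 + 2, 4)
phi_u = mlp(4 + 4, 4)
w_r = rng.normal(size=4)                     # phi_r: a plain linear readout


def predict(h, edge_index, e, K=2):
    for _ in range(K):                       # same phi_m, phi_u reused every round
        h = message_passing_round(h, edge_index, e, phi_m, phi_u)
    return float(h.sum(axis=0) @ w_r)        # y_hat = phi_r(sum_i h_i)


# Relabel the atoms: permute node features and remap edge endpoints
perm = np.array([2, 0, 1])
inv = np.argsort(perm)                       # inv[old_index] = new_index
y = predict(h0, edge_index, e)
y_perm = predict(h0[perm], inv[edge_index], e)
print(np.isclose(y, y_perm))                 # True
```

Sharing weights across rounds is one design choice; giving each round its own $\phi_m$, $\phi_u$ is equally valid and matches the superscripted layer index more literally.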
In DFT, the Hohenberg–Kohn theorem states that the ground-state total energy is a functional of the electron density, and Kohn's principle of the "nearsightedness" of electronic matter says that local electronic properties depend mainly on the nearby environment. GNNs implement a learned analogue of this locality principle: each atom's contribution to a property depends on its local chemical environment, captured by the hidden state $h_i^{(K)}$ after K rounds of message passing.
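The "local environment" in this analogy is the K-hop neighbourhood introduced at the start of this section. A small NumPy sketch makes that concrete, assuming a plain adjacency matrix A (receptive_field is a hypothetical helper): node j can influence node i after K rounds exactly when the (i, j) entry of (I + A)^K is non-zero, i.e. when j lies within K hops of i.

```python
import numpy as np


def receptive_field(adjacency, k):
    """R[i, j] is True if node j can influence node i after k rounds,
    i.e. j lies within k hops of i (non-zero pattern of (I + A)^k)."""
    a = np.asarray(adjacency) != 0
    step = (a | np.eye(len(a), dtype=bool)).astype(int)  # each round keeps h_i itself
    reach = np.eye(len(a), dtype=int)
    for _ in range(k):
        reach = ((reach @ step) > 0).astype(int)
    return reach.astype(bool)


# A 5-atom chain 0-1-2-3-4
chain = np.zeros((5, 5), dtype=int)
for i in range(4):
    chain[i, i + 1] = chain[i + 1, i] = 1

print(np.flatnonzero(receptive_field(chain, 2)[0]))  # [0 1 2]
```

On the chain, atom 0 sees only atoms 0, 1 and 2 after two rounds. With a typical 8 Å cutoff a single hop already spans many atoms, so a few rounds cover a large chemical environment.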
3. CGCNN — The First Graph Network for Crystals
Crystal Graph Convolutional Neural Network (CGCNN), introduced by Xie & Grossman (2018), was the first GNN specifically designed for crystalline materials. It demonstrated that a single architecture, trained separately for each target, could predict formation energy, band gap, Fermi energy, bulk modulus, shear modulus, and Poisson's ratio from the crystal graph alone, without any hand-engineered descriptors.
CGCNN architecture
Formation energy MAE ≈ 0.039 eV/atom · Band gap MAE ≈ 0.388 eV · Bulk modulus MAE ≈ 0.054 log(GPa) — all significantly better than earlier ML models using fixed fingerprints like SOAP or ACSF.
CGCNN in PyTorch Geometric — the convolution layer
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing


class CGCNNConv(MessagePassing):
    """Gated graph convolution in the style of CGCNN (Xie & Grossman, 2018)."""

    def __init__(self, node_dim, edge_dim):
        super().__init__(aggr='add')  # sum aggregation over neighbours
        # Message network: concat(h_i, h_j, e_ij) -> [gate | core], 2*node_dim wide
        self.msg_net = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim, 2 * node_dim),
            nn.BatchNorm1d(2 * node_dim),
        )
        # Normalises the summed messages before the residual update
        self.bn_out = nn.BatchNorm1d(node_dim)

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        cat = torch.cat([x_i, x_j, edge_attr], dim=-1)
        m = self.msg_net(cat)            # shape: [num_edges, 2*node_dim]
        gate, feat = m.chunk(2, dim=-1)  # gated mechanism
        return torch.sigmoid(gate) * F.softplus(feat)

    def update(self, aggr_out, x):
        # Residual update: h_i' = softplus(h_i + BN(sum_j m_ij))
        return F.softplus(x + self.bn_out(aggr_out))
```
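For readers without PyTorch Geometric, the same gated update can be sketched in plain NumPy. Batch normalisation is omitted, the weights are random rather than trained, and cgcnn_conv is a hypothetical helper name; the point is the structure: a sigmoid gate times a softplus core on each edge, a sum over neighbours, and a residual softplus update.

```python
import numpy as np


def softplus(x):
    return np.logaddexp(0.0, x)  # numerically stable log(1 + exp(x))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def cgcnn_conv(h, edge_index, e, W, b):
    """One CGCNN-style gated convolution (batch normalisation omitted).

    z_ij = [h_i, h_j, e_ij] @ W + b, split into a gate half and a core half;
    m_ij = sigmoid(gate) * softplus(core);  h_i' = softplus(h_i + sum_j m_ij).
    """
    src, dst = edge_index  # message flows j=src -> i=dst
    z = np.concatenate([h[dst], h[src], e], axis=1) @ W + b
    gate, core = np.split(z, 2, axis=1)
    m = sigmoid(gate) * softplus(core)
    agg = np.zeros_like(h)
    np.add.at(agg, dst, m)  # sum messages per receiving atom
    return softplus(h + agg)  # residual update


rng = np.random.default_rng(0)
node_dim, edge_dim = 4, 3
h = rng.normal(size=(3, node_dim))
edge_index = np.array([[1, 0],   # src: edges 1->0 and 0->1
                       [0, 1]])  # dst: atom 2 has no neighbours
e = rng.normal(size=(2, edge_dim))
W = rng.normal(scale=0.3, size=(2 * node_dim + edge_dim, 2 * node_dim))
b = np.zeros(2 * node_dim)

h_new = cgcnn_conv(h, edge_index, e, W, b)
print(h_new.shape)                            # (3, 4)
print(np.allclose(h_new[2], softplus(h[2])))  # True: isolated atom keeps only the residual
```

The sigmoid gate lies in (0, 1), so it can learn to suppress messages from less relevant neighbours, while the residual term keeps each atom's own state in the update.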
4. ALIGNN — Adding Bond Angles to the Graph
ALIGNN (Atomistic Line Graph Neural Network, Choudhary & DeCost 2021) extends CGCNN with a key insight: bond angles carry information that bond distances alone do not. In a perovskite, for example, the Fe–O–Fe bond angle strongly influences the superexchange interaction and hence the magnetic ordering temperature. Ignoring angles means ignoring one of the most important structural fingerprints.
ALIGNN's solution is elegant: it builds two graphs simultaneously and runs message passing on both:
- Atom graph G: nodes = atoms, edges = bonds within the cutoff (same as CGCNN). Node features h_i are updated by bond messages.
- Line graph L(G): nodes = bonds from G, edges = pairs of bonds that share an atom. Edge feature = bond angle θ_ijk. Bond features e_ij are updated by angle messages.
- Coupled update: the atom graph uses the updated bond features, and the line graph uses the updated atom features, so the two graphs "talk to each other" at every layer.
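If G has bonds (i–j) and (j–k), both sharing atom j, then in the line graph L(G) these two bonds become nodes connected by an edge whose feature is the bond angle ∠i–j–k. In JARVIS-Tools, the periodic neighbour list comes from jarvis.core.atoms.Atoms.get_all_neighbors(), and Graph.atom_dgl_multigraph (used in Section 6) can return the line graph, with bond-angle cosines on its edges, alongside the atom graph.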
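A minimal NumPy sketch of the line-graph construction, assuming a non-periodic toy geometry and an undirected bond list (ALIGNN itself works on directed edges that include periodic images). line_graph_with_angles is a hypothetical helper: every pair of bonds sharing exactly one atom becomes a line-graph edge carrying the angle at that shared atom.

```python
import numpy as np


def line_graph_with_angles(positions, bonds):
    """Build line-graph edges and bond angles from an undirected bond list.

    positions: (N, 3) Cartesian coordinates.
    bonds: list of (i, j) atom pairs; each bond becomes a line-graph node.
    Returns (bond_a, bond_b, shared_atom, angle_deg) for every pair of bonds
    that share exactly one atom.
    """
    pos = np.asarray(positions, dtype=float)
    triplets = []
    for a in range(len(bonds)):
        for b in range(a + 1, len(bonds)):
            shared = set(bonds[a]) & set(bonds[b])
            if len(shared) != 1:
                continue  # no common atom (or a duplicate bond): no line-graph edge
            j = shared.pop()
            i = bonds[a][0] if bonds[a][1] == j else bonds[a][1]
            k = bonds[b][0] if bonds[b][1] == j else bonds[b][1]
            v1, v2 = pos[i] - pos[j], pos[k] - pos[j]
            cos = v1 @ v2 / (np.linalg.norm(v1) * np.linalg.norm(v2))
            angle = np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))  # angle i-j-k
            triplets.append((a, b, j, float(angle)))
    return triplets


# Toy geometry: atom 0 (say, O) bonded to three neighbours
positions = [[0, 0, 0], [2, 0, 0], [-2, 0, 0], [0, 2, 0]]
bonds = [(0, 1), (0, 2), (0, 3)]
triplets = line_graph_with_angles(positions, bonds)
print([(a, b) for a, b, _, _ in triplets])  # [(0, 1), (0, 2), (1, 2)]
print([round(t[3], 1) for t in triplets])   # [180.0, 90.0, 90.0]
```

The three bonds become three line-graph nodes, and the three angles at the shared atom become their edge features. A distance-only model sees identical bond lengths here; the line graph is what distinguishes the linear (180°) arrangement from the right-angle ones.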
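Line graph pass (updates bond features using angle context):
$$e_{ij}^{(k+1)} = \phi_L\Big(e_{ij}^{(k)},\ \sum_{(jl) \in \mathcal{N}_L(ij)} \phi_a\big(e_{ij}^{(k)},\ e_{jl}^{(k)},\ \theta_{ijl}\big)\Big)$$
Here the sum runs over the bonds (j–l) that share atom j with bond (i–j), and θ_ijl is the angle ∠i–j–l; the third atom is labelled l rather than k because k already indexes the layer.
Atom graph pass (updates atom features using bond context):
$$h_i^{(k+1)} = \phi_G\Big(h_i^{(k)},\ \sum_{j \in \mathcal{N}(i)} \phi_m\big(h_i^{(k)},\ h_j^{(k)},\ e_{ij}^{(k+1)}\big)\Big)$$
Readout (mean pooling over the N atoms):
$$\hat{y} = \mathrm{MLP}\Big(\frac{1}{N} \sum_{i} h_i^{(K)}\Big)$$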
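The 1/N factor matters: mean pooling makes the prediction intensive, which suits per-atom targets such as formation energy in eV/atom, whereas the sum readout of Section 2 is extensive. A small NumPy check, assuming that each atom in a 2× supercell ends up with the same hidden state as its image in the original cell (which holds when every atom has the same periodic environment); readout is a hypothetical helper:

```python
import numpy as np


def readout(h, mode):
    """Pool per-atom vectors h of shape (N, d) into one graph-level vector."""
    return h.mean(axis=0) if mode == "mean" else h.sum(axis=0)


rng = np.random.default_rng(1)
h_cell = rng.normal(size=(5, 8))       # e.g. a 5-atom cell, 8 features per atom
h_super = np.vstack([h_cell, h_cell])  # same crystal described as a 2x supercell

print(np.allclose(readout(h_cell, "mean"), readout(h_super, "mean")))    # True
print(np.allclose(2 * readout(h_cell, "sum"), readout(h_super, "sum")))  # True
```

With mean pooling, the choice of unit cell does not change the predicted per-atom property; with sum pooling, the prediction scales with the number of atoms.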
ALIGNN vs CGCNN: test MAE on the Materials Project benchmark (lower is better)
| Property | CGCNN MAE | ALIGNN MAE | Change in MAE |
|---|---|---|---|
| Formation energy (eV/atom) | 0.039 | 0.022 | −44% |
| Band gap (eV) | 0.388 | 0.218 | −44% |
| Bulk modulus log(GPa) | 0.054 | 0.051 | −6% |
| Shear modulus log(GPa) | 0.087 | 0.078 | −10% |
5. The GNN Family — CGCNN, ALIGNN, MEGNet, SchNet at a Glance
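- CGCNN: First graph model for crystals. Bond distances as edge features. Gated message passing. Fast, widely used as a baseline. The reference implementation is in PyTorch, and PyTorch Geometric provides the layer as CGConv.
- ALIGNN: Adds bond angles via the line graph. Substantially better than CGCNN on formation energy and band gap, with smaller gains on elastic moduli (see the table above). Built on JARVIS-Tools and distributed as the alignn package. Best-in-class on most JARVIS benchmarks at publication.
- MEGNet: Atom, bond and global-state updates. Handles molecules and crystals. TensorFlow/Keras. Trained on Materials Project data; its successor M3GNet adds explicit three-body interactions.
- SchNet: Continuous-filter convolutions, in which a small network generates the convolution filters from Gaussian-expanded distances. Rotationally invariant. Well suited to molecular dynamics and interatomic potentials; later equivariant models such as NequIP and MACE build on this line of work.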
6. From CIF File to Graph — A Practical Walkthrough
Let us build a crystal graph for SrTiO₃ perovskite starting from a CIF file, using the JARVIS-Tools library (the ALIGNN model itself is distributed as the companion alignn package, built on JARVIS-Tools).
```python
from jarvis.core.atoms import Atoms
from jarvis.core.graphs import Graph

# Load structure from CIF (or POSCAR, VASP, etc.)
atoms = Atoms.from_cif('SrTiO3.cif')
print(f"Atoms: {atoms.elements}")  # ['Sr', 'Ti', 'O', 'O', 'O']

# Build crystal graph (atom graph only; the line graph is discussed in Section 4)
g = Graph.atom_dgl_multigraph(
    atoms,
    cutoff=8.0,                # neighbour cutoff in Å
    max_neighbors=12,          # neighbours kept per atom
    atom_features='cgcnn',     # 92-dim CGCNN atom features
    compute_line_graph=False,  # return only g instead of the pair (g, lg)
)
print(f"Nodes: {g.num_nodes()}")  # 5 atoms
print(f"Edges: {g.num_edges()}")  # directed edges, including periodic images
print(f"Node feat dim: {g.ndata['atom_features'].shape[1]}")  # 92
print(f"Edge feat dim: {g.edata['r'].shape[1]}")  # 3 (displacement vector)
```
Running ALIGNN predictions with pre-trained weights
```python
from alignn.pretrained import get_prediction
from jarvis.core.atoms import Atoms
from jarvis.db.figshare import data as jdata

# Pre-trained ALIGNN model for formation energy (downloaded on first use)
MODEL = 'jv_formation_energy_peratom_alignn'

# Predict for SrTiO3; get_prediction returns a list of model outputs
atoms = Atoms.from_cif('SrTiO3.cif')
eform = get_prediction(model_name=MODEL, atoms=atoms)[0]
print(f"Predicted Eform: {eform:.4f} eV/atom")  # ≈ −3.52 eV/atom

# Run on a list of structures (simple, not optimised for large batches)
dft_3d = jdata('dft_3d')[:100]  # first 100 JARVIS-DFT entries
predictions = [
    get_prediction(model_name=MODEL, atoms=Atoms.from_dict(d['atoms']))[0]
    for d in dft_3d
]
```
7. Training Your Own ALIGNN Model — Formation Energy from JARVIS-DFT
```python
from alignn.config import TrainingConfig
from alignn.train import train_dgl

# Field names follow alignn/config.py; check them against your installed version.
config = TrainingConfig(
    dataset='dft_3d',
    target='formation_energy_peratom',
    random_seed=123,
    epochs=100,
    n_train=44578,
    n_val=5572,
    n_test=5572,
    batch_size=64,
    learning_rate=1e-3,
    criterion='mse',
    optimizer='adamw',
    scheduler='onecycle',
    # Architecture hyperparameters live in the nested model config
    model={
        'name': 'alignn',
        'alignn_layers': 4,            # ALIGNN layers (line graph + atom graph)
        'gcn_layers': 4,               # atom-graph-only layers that follow
        'atom_input_features': 92,
        'edge_input_features': 80,     # 80-dim GBF for bond distances
        'triplet_input_features': 40,  # 40-dim GBF for bond angles
        'hidden_features': 256,
        'output_features': 1,
    },
)
train_dgl(config)  # trains the model and saves the results
```
Training ALIGNN from scratch on JARVIS-DFT (~55,000 structures) requires a GPU and roughly 12 h of training. For smaller datasets (< 2,000 structures), fine-tune a pretrained model instead: freeze the graph convolution layers and retrain only the output head at a low learning rate (e.g. 1e-4). This is the materials science equivalent of transfer learning from ImageNet.
8. GNN vs CNN vs MLP — When to Use Each
| Property | MLP | CNN | GNN |
|---|---|---|---|
| Input type | Fixed-length feature vector | 2D/3D image (regular grid) | Graph (variable atoms/bonds) |
| Spatial awareness | ✗ None | ✓ Local grid neighbourhoods (2D/3D) | ✓ Full bonding topology |
| Handles periodic boundaries | ✗ | Partial (e.g. circular padding) | ✓ (edges to periodic images) |
| Rotation invariance | ✗ (depends on input features) | Approximate, via data augmentation | ✓ (built from distances and angles) |
| Best materials use | Composition-only models, fast screening | Phase ID from diffraction/SEM images | Structure → property (Eform, Eg, moduli) |
| Data need | ~100–1,000 | ~1,000–10,000 images | ~5,000–100,000 structures |
| Key libraries | scikit-learn, PyTorch | torchvision, PyTorch | PyTorch Geometric, DGL, JARVIS |
Quick Check
1. In CGCNN, what does an edge in the crystal graph represent, and what features does it carry?
2. ALIGNN improves over CGCNN mainly because it explicitly encodes structural information that CGCNN misses. What is that information?
3. You have 800 DFT-computed formation energies for a family of novel perovskites not in any public database. What is the best GNN strategy, and why?