
AI & Machine Learning for Materials Sciences


Post 12: Graph Neural Networks — ALIGNN & CGCNN Explained for Beginners

How GNNs read a crystal structure directly, without handcrafted features, and why they outperform CNNs for property prediction.

Module 3
🔗
Key Abstraction

Crystal = graph (atoms = nodes, bonds = edges)

📨
Core Operation

Message passing between neighbours

🏗️
Architectures

CGCNN, ALIGNN, MEGNet, SchNet

🎯
Materials Target

Formation energy, band gap, elastic moduli

In Post 11 we saw how CNNs extract spatial patterns from 2D images of diffraction patterns. But most materials data is not an image — it is an atomic structure: a set of atoms at precise 3D coordinates, connected by bonds, embedded in a periodic crystal lattice. Squashing this rich relational information into a flat pixel grid or a hand-engineered feature vector throws away most of what matters.

Graph Neural Networks (GNNs) keep the structure intact. They represent a crystal as a graph — atoms become nodes, bonds or neighbourhoods become edges — and then learn to propagate information along those edges. After several rounds of message-passing, each node's representation reflects not only its own element but also its chemical environment, enabling highly accurate property prediction directly from structure.

🔬
Why not use a CNN or MLP on crystal data?

A multilayer perceptron (MLP) requires a fixed-length input vector, forcing us to choose features in advance (ionic radius, electronegativity, coordination number…). A CNN needs a regular 2D or 3D grid — but crystal structures have variable numbers of atoms per unit cell and complex topology that doesn't map cleanly to a grid. GNNs accept graphs of any size and topology, making them the natural architecture for crystal data.

1. What Is a Graph? — Translating a Crystal Into Nodes and Edges

A graph G = (V, E) is simply a collection of nodes V (vertices) connected by edges E. Each node and edge can carry a feature vector. In the materials context:

  • Nodes → atoms. Node features: atomic number Z, electronegativity, ionic radius, oxidation state.
  • Edges → bonds or proximity pairs within a cutoff radius (typically 8 Å). Edge features: bond distance, usually Gaussian-expanded; ALIGNN additionally encodes bond angles through a second graph (Section 4).
  • Graph-level label → the property we want to predict (formation energy eV/atom, band gap eV, bulk modulus GPa…).
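Here is a minimal NumPy sketch of how edges are built for a periodic crystal. It assumes a plain distance cutoff and a brute-force search over neighbouring unit-cell images. That is adequate for small orthogonal cells like this one; libraries such as JARVIS-Tools use faster neighbour-list algorithms. `radius_graph` is a hypothetical helper written for this example, not a library function. For BCC iron (a = 2.87 Å) with a 3 Å cutoff, it should recover two neighbour shells: eight neighbours at a√3/2 ≈ 2.485 Å and six at a = 2.87 Å.

```python
import itertools
import numpy as np

def radius_graph(frac_coords, lattice, cutoff):
    """Return (i, j, image, distance) for every directed pair within `cutoff` (Å),
    including bonds that point into periodic images of the unit cell."""
    cart = frac_coords @ lattice                  # fractional -> Cartesian (rows = a, b, c)
    # Number of cell images to search along each axis (sufficient for orthogonal cells)
    reps = [int(np.ceil(cutoff / np.linalg.norm(v))) for v in lattice]
    edges = []
    for image in itertools.product(*(range(-r, r + 1) for r in reps)):
        shift = np.array(image) @ lattice         # Cartesian translation of this image
        for i, ri in enumerate(cart):
            for j, rj in enumerate(cart):
                d = float(np.linalg.norm(rj + shift - ri))
                if 1e-8 < d <= cutoff:            # skip an atom's zero-length self-pair
                    edges.append((i, j, image, d))
    return edges

a = 2.87                                          # BCC Fe lattice constant (Å)
lattice = a * np.eye(3)
frac = np.array([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]])  # 2-atom conventional cell
edges = radius_graph(frac, lattice, cutoff=3.0)
dists = [e[3] for e in edges]
print(len(edges), round(min(dists), 3), round(max(dists), 3))  # 28 2.485 2.87
```

Each edge records the periodic image it points into. This is how a finite graph represents an infinite crystal. Raising the cutoff towards the 8 Å used in practice adds further neighbour shells.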

🔗 Crystal → Graph Conversion (BCC Iron, 2-atom unit cell)

[Figure: four Fe nodes (Z = 26) linked by nearest-neighbour edges of 2.48 Å, plus a 2.87 Å edge along the cube edge. Node feature: Z = 26, χ = 1.83, r = 1.26 Å. Edge feature: d = 2.48 Å, expanded into a 20-component Gaussian basis (GBF).]

The Gaussian basis function (GBF) expansion of distances is an important preprocessing step: instead of storing a raw scalar d = 2.48 Å, we expand it into a vector of 20 Gaussian "bumps" centred at evenly spaced distances between 0 and 8 Å. This gives the network a smooth, differentiable representation of distance that it can easily learn from.
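The sketch below shows the expansion. It assumes the Gaussian form exp(−(d − μ_k)²/σ²), with the width σ set to the grid spacing, which is the convention used in CGCNN's reference code. `gaussian_expand` is a hypothetical helper name.

```python
import numpy as np

def gaussian_expand(d, d_min=0.0, d_max=8.0, n_centres=20, width=None):
    """Expand distance(s) d (Å) into n_centres Gaussians evenly spaced on [d_min, d_max]."""
    centres = np.linspace(d_min, d_max, n_centres)
    if width is None:
        width = centres[1] - centres[0]          # width = one grid spacing
    d = np.atleast_1d(d)[:, None]                # shape (n_distances, 1)
    return np.exp(-((d - centres) ** 2) / width ** 2)

e = gaussian_expand(2.48)
print(e.shape)               # (1, 20)
print(np.argmax(e[0]))       # 6: the centre at about 2.53 Å is the closest to 2.48 Å
```

With the defaults, a 2.48 Å bond mainly activates the centre at about 2.53 Å and its immediate neighbours. The rest of the vector stays close to zero. Calling `gaussian_expand(d, n_centres=41)` gives a 0.2 Å spacing on 0 to 8 Å, which matches the 41-dimensional edge vector used by CGCNN.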

2. Message Passing — How Information Travels Through the Graph

The central idea of GNNs is message passing: each node collects information (messages) from its neighbours, aggregates them, and updates its own hidden representation. After K rounds of message passing, a node's representation encodes information from all nodes within K hops — its K-hop neighbourhood.
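You can check the K-hop claim on a toy graph. This sketch assumes a plain path graph and uses a hypothetical `receptive_field` helper. It tracks which nodes' initial features can reach a given node when each round combines the node's own state with the sum of its neighbours' messages.

```python
import numpy as np

# Path graph 0-1-2-3-4-5 (for example, a chain of atoms)
n = 6
A = np.zeros((n, n), dtype=int)
for i in range(n - 1):
    A[i, i + 1] = A[i + 1, i] = 1

def receptive_field(A, node, K):
    """Nodes whose initial features can influence `node` after K rounds of
    message passing (own state plus summed neighbour messages each round)."""
    reach = np.zeros(len(A), dtype=int)
    reach[node] = 1
    step = A + np.eye(len(A), dtype=int)     # self term + neighbour messages
    for _ in range(K):
        reach = (step @ reach > 0).astype(int)
    return set(np.flatnonzero(reach).tolist())

for K in range(4):
    print(K, sorted(receptive_field(A, 0, K)))
# 0 [0]
# 1 [0, 1]
# 2 [0, 1, 2]
# 3 [0, 1, 2, 3]
```

In a crystal graph each atom typically has about 12 neighbours. The K-hop neighbourhood therefore grows much faster than on this chain, so a few message-passing layers already cover a sizeable chemical environment.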

Step 1
Compute messages

For each edge connecting node i to a neighbour j, compute a message $m_{ij}$ from the features of nodes i and j and the edge features $e_{ij}$.

Step 2
Aggregate

Sum (or average) all incoming messages to node i: $M_i = \sum_{j \in \mathcal{N}(i)} m_{ij}$

Step 3
Update

Pass the aggregated message, together with the current state, through a neural network to update node i's hidden state $h_i$.

Step 4
Readout

After K rounds, pool all node vectors into a graph-level vector → predict property.

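📨 General Message-Passing Update Rule (MPNN framework)
$$m_{ij}^{(k)} = \phi_m\big(h_i^{(k)},\, h_j^{(k)},\, e_{ij}\big) \qquad \text{(message function)}$$

$$M_i^{(k)} = \sum_{j \in \mathcal{N}(i)} m_{ij}^{(k)} \qquad \text{(aggregation: sum)}$$

$$h_i^{(k+1)} = \phi_u\big(h_i^{(k)},\, M_i^{(k)}\big) \qquad \text{(update function)}$$

$$\hat{y} = \phi_r\Big(\sum_i h_i^{(K)}\Big) \qquad \text{(readout / pooling)}$$

$\phi_m$, $\phi_u$ and $\phi_r$ are small MLPs with learnable weights.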
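The four equations can be run end to end in a few lines of NumPy. This is an illustrative sketch, not a trainable model:

- the weights are random;
- a single linear layer with softplus stands in for each MLP;
- the same φm and φu are reused in every round, whereas real models usually give each layer its own weights;
- φr is a plain linear map.

The final check shows that relabelling the atoms leaves ŷ unchanged, because sum aggregation and sum pooling do not depend on atom order.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_layer(in_dim, out_dim):
    """One linear layer + softplus with random, untrained weights (stand-in for a small MLP)."""
    W = rng.normal(scale=0.1, size=(in_dim, out_dim))
    return lambda x: np.logaddexp(0.0, x @ W)    # softplus(x W), numerically stable

def mpnn_round(h, edges, e, phi_m, phi_u, msg_dim):
    """m_ij = phi_m(h_i, h_j, e_ij); M_i = sum_j m_ij; h_i' = phi_u(h_i, M_i)."""
    M = np.zeros((h.shape[0], msg_dim))
    for (i, j), e_ij in zip(edges, e):
        M[i] += phi_m(np.concatenate([h[i], h[j], e_ij]))  # message from j to i
    return phi_u(np.concatenate([h, M], axis=1))

node_dim, edge_dim, msg_dim, K = 4, 2, 8, 3
phi_m = make_layer(2 * node_dim + edge_dim, msg_dim)
phi_u = make_layer(node_dim + msg_dim, node_dim)
W_r = rng.normal(size=(node_dim, 1))             # linear readout head (phi_r)

def predict(h0, edges, e):
    h = h0
    for _ in range(K):                           # same weights reused each round (simplification)
        h = mpnn_round(h, edges, e, phi_m, phi_u, msg_dim)
    return h.sum(axis=0) @ W_r                   # y_hat = phi_r(sum_i h_i^(K))

# Toy 3-atom chain 0-1-2; each bond listed in both directions as (receiver i, sender j)
h0 = rng.normal(size=(3, node_dim))
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
e = rng.normal(size=(len(edges), edge_dim))
y = predict(h0, edges, e)

# Relabel atoms: new index p holds old atom perm[p]; the prediction must not change
perm = [2, 0, 1]
new_index = {old: new for new, old in enumerate(perm)}
edges_perm = [(new_index[i], new_index[j]) for i, j in edges]
y_perm = predict(h0[perm], edges_perm, e)
print(np.allclose(y, y_perm))                    # True
```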
💡
Analogy with DFT

In DFT, the ground-state energy is a functional of the electron density (Hohenberg–Kohn theorem). Kohn's "nearsightedness" principle adds that local electronic properties depend mainly on the nearby environment. GNNs implement a learned version of this locality principle: each atom's contribution to a property depends on its local chemical environment, captured by the hidden state $h_i^{(K)}$ after K rounds of message passing.

3. CGCNN — The First Graph Network for Crystals

Crystal Graph Convolutional Neural Network (CGCNN), introduced by Xie & Grossman (2018), was the first GNN specifically designed for crystalline materials. It showed that one architecture, trained separately for each target, could predict formation energy, band gap, Fermi energy, bulk modulus, shear modulus and Poisson's ratio from the crystal graph alone, without any hand-engineered descriptors.

CGCNN architecture

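  • Input graph: atoms as nodes (a 92-dimensional vector of one-hot-encoded elemental properties such as group, period, electronegativity and covalent radius), bonds as edges (GBF-expanded distance, 41-dimensional vector).
  • Graph convolution layers (3–10): each applies the message-passing update. The message function φm concatenates (hi, hj, eij) and passes the result through a fully connected layer. Half of the output, squashed by a sigmoid, gates the softplus-activated other half. The summed messages are batch-normalised, added back to hi as a residual and passed through a softplus.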
  • Pooling: mean over all atom-level representations to produce a fixed-length graph vector.
  • Output head: two fully connected layers → scalar property prediction.
    📊
    CGCNN performance benchmarks (Materials Project, ~28,000 structures)

    Formation energy MAE ≈ 0.039 eV/atom · Band gap MAE ≈ 0.388 eV · Bulk modulus MAE ≈ 0.054 log(GPa) — all significantly better than earlier ML models using fixed fingerprints like SOAP or ACSF.

    CGCNN in PyTorch Geometric — the convolution layer

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch_geometric.nn import MessagePassing

    class CGCNNConv(MessagePassing):
        def __init__(self, node_dim, edge_dim):
            super().__init__(aggr='add') # sum aggregation over neighbours
            # Message network: concat(h_i, h_j, e_ij) → [gate | core], 2*node_dim wide
            self.msg_net = nn.Sequential(
                nn.Linear(2*node_dim + edge_dim, 2*node_dim),
                nn.BatchNorm1d(2*node_dim),
            )
            # Batch norm applied to the summed messages before the residual update
            self.bn_aggr = nn.BatchNorm1d(node_dim)

        def forward(self, x, edge_index, edge_attr):
            return self.propagate(edge_index, x=x, edge_attr=edge_attr)

        def message(self, x_i, x_j, edge_attr):
            cat = torch.cat([x_i, x_j, edge_attr], dim=-1)
            gate, core = self.msg_net(cat).chunk(2, dim=-1) # split the 2*node_dim output
            return torch.sigmoid(gate) * F.softplus(core)    # gated message m_ij

        def update(self, aggr_out, x):
            # Residual update: h_i ← softplus(h_i + BN(Σ_j m_ij))
            return F.softplus(x + self.bn_aggr(aggr_out))
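Why mean pooling? For per-atom targets such as formation energy (eV/atom) or band gap, the prediction should not depend on which unit cell you happen to choose. The sketch below makes a hypothetical assumption: every Fe site in BCC iron ends up with the same hidden vector after message passing. An invariant GNN guarantees this when all sites have identical neighbourhoods within the cutoff. The sketch compares mean and sum pooling over three cells:

- the 1-atom primitive cell;
- the 2-atom conventional cell;
- a 16-atom 2×2×2 supercell.

```python
import numpy as np

def mean_pool(h):
    """CGCNN-style readout: average the atom vectors."""
    return h.mean(axis=0)

def sum_pool(h):
    """Alternative readout: add the atom vectors."""
    return h.sum(axis=0)

# Hypothetical converged hidden vector shared by every (equivalent) Fe site
h_Fe = np.array([0.3, -1.2, 0.8])
primitive = np.tile(h_Fe, (1, 1))       # 1-atom primitive cell
conventional = np.tile(h_Fe, (2, 1))    # 2-atom conventional cell
supercell = np.tile(h_Fe, (16, 1))      # 2x2x2 supercell of the conventional cell

for name, h in [("primitive", primitive), ("conventional", conventional), ("supercell", supercell)]:
    # Mean pooling gives the same vector for all three cells; sum pooling grows with atom count
    print(name, mean_pool(h), sum_pool(h))
```

Sum pooling suits extensive quantities such as total energy. Mean pooling keeps intensive, per-atom targets independent of cell size.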

    4. ALIGNN — Adding Bond Angles to the Graph

    ALIGNN (Atomistic Line Graph Neural Network, Choudhary & DeCost 2021) extends CGCNN with one key insight: bond angles carry structural information that bond distances alone do not. In a perovskite, for example, the Fe–O–Fe bond angle strongly influences the superexchange interaction and hence the magnetic ordering temperature. A model that ignores angles misses one of the most important structural fingerprints.

    ALIGNN's solution is elegant: it builds two graphs simultaneously and runs message passing on both:

    🔵

    Atom graph G

    Nodes = atoms, edges = bonds within cutoff. Same as CGCNN. Node features hi updated by bond messages.

    🔶

    Line graph L(G)

    Nodes = bonds from G, edges = shared-atom bond pairs. Edge feature = bond angle θijk. Bond features eij updated by angle messages.

    🔄

    Coupled update

    Atom graph uses updated bond features; line graph uses updated atom features. The two graphs "talk to each other" at every layer.

    📐
    What is a line graph?

    If G has bonds (i–j) and (j–k), both sharing atom j, then in the line graph L(G) these two bonds become nodes connected by an edge whose feature is the bond angle ∠i–j–k. You can get the neighbour list needed to find such bond pairs from a CIF or POSCAR with JARVIS-Tools (jarvis.core.atoms.Atoms.get_all_neighbors()). The angles are then computed from the bond vectors.
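Here is a minimal sketch of line-graph construction. It assumes Cartesian positions in Å and an explicit bond list for a finite cluster. For brevity, it ignores periodic images and treats bonds as undirected; real implementations such as ALIGNN work with directed bonds and periodic images. `line_graph_angles` is a hypothetical helper.

The example applies it to a TiO6 octahedron like the one in SrTiO3, where Ti–O ≈ 1.95 Å, half the ≈3.905 Å lattice constant. The six Ti–O bonds become six line-graph nodes. The 15 bond pairs sharing Ti carry angles of 90° (12 pairs) or 180° (3 pairs).

```python
import itertools
import numpy as np

def line_graph_angles(positions, bonds):
    """Build L(G): each bond (i, j) becomes a node, and every pair of bonds that
    share exactly one atom j is linked by an edge carrying the angle i-j-k (degrees)."""
    lg_edges = []
    for (a, b), (c, d) in itertools.combinations(bonds, 2):
        shared = {a, b} & {c, d}
        if len(shared) != 1:
            continue                              # bonds do not touch: no line-graph edge
        j = shared.pop()                          # the shared (apex) atom
        i = a if b == j else b                    # outer atom of the first bond
        k = c if d == j else d                    # outer atom of the second bond
        u = positions[i] - positions[j]
        v = positions[k] - positions[j]
        cos_t = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
        theta = np.degrees(np.arccos(np.clip(cos_t, -1.0, 1.0)))
        lg_edges.append(((a, b), (c, d), float(theta)))
    return lg_edges

ti_o = 1.95                                       # Ti-O distance (Å), roughly a/2 in SrTiO3
pos = np.array([[0, 0, 0],
                [ti_o, 0, 0], [-ti_o, 0, 0],
                [0, ti_o, 0], [0, -ti_o, 0],
                [0, 0, ti_o], [0, 0, -ti_o]], dtype=float)
bonds = [(0, n) for n in range(1, 7)]             # Ti (atom 0) bonded to six O atoms
lg = line_graph_angles(pos, bonds)
angles = [round(t[2]) for t in lg]
print(len(lg), angles.count(90), angles.count(180))  # 15 12 3
```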

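    📐 ALIGNN Update Equations (simplified; the layer index is written t to avoid a clash with the atom label k)
    Line graph pass (updates bond features using angle context):
    $$e_{ij}^{(t+1)} = \phi_L\Big(e_{ij}^{(t)},\ \sum_{jk \in \mathcal{N}_L(ij)} \phi_a\big(e_{ij}^{(t)},\, e_{jk}^{(t)},\, \theta_{ijk}\big)\Big)$$

    Atom graph pass (updates atom features using bond context):
    $$h_i^{(t+1)} = \phi_G\Big(h_i^{(t)},\ \sum_{j \in \mathcal{N}(i)} \phi_m\big(h_i^{(t)},\, h_j^{(t)},\, e_{ij}^{(t+1)}\big)\Big)$$

    Readout: $$\hat{y} = \mathrm{MLP}\Big(\frac{1}{N} \sum_i h_i^{(T)}\Big)$$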

    ALIGNN performance vs CGCNN

    Property | CGCNN MAE | ALIGNN MAE | Change
    Formation energy (eV/atom) | 0.039 | 0.022 | −44%
    Band gap (eV) | 0.388 | 0.218 | −44%
    Bulk modulus (log GPa) | 0.054 | 0.051 | −6%
    Shear modulus (log GPa) | 0.087 | 0.078 | −10%
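The Change column is the relative change in MAE, (MAE_ALIGNN − MAE_CGCNN) / MAE_CGCNN. Here is a quick check using the values in the table:

```python
# Relative change in MAE, ALIGNN vs CGCNN, using the values from the table
maes = {
    "Formation energy (eV/atom)": (0.039, 0.022),
    "Band gap (eV)": (0.388, 0.218),
    "Bulk modulus (log GPa)": (0.054, 0.051),
    "Shear modulus (log GPa)": (0.087, 0.078),
}
for prop, (cgcnn, alignn) in maes.items():
    change = 100 * (alignn - cgcnn) / cgcnn      # negative = lower error than CGCNN
    print(f"{prop:28s} {change:+.0f}%")
# Formation energy (eV/atom)   -44%
# Band gap (eV)                -44%
# Bulk modulus (log GPa)       -6%
# Shear modulus (log GPa)      -10%
```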

    5. The GNN Family — CGCNN, ALIGNN, MEGNet, SchNet at a Glance

    CGCNN
    Crystal Graph Convolutional NN

    First graph model for crystals. Bond distances as edge features. Gated message passing. Fast, widely used as a baseline. Reference PyTorch implementation by the authors; PyTorch Geometric also ships the layer as CGConv.

    ALIGNN
    Atomistic Line Graph NN

    Adds bond angles via the line graph. Significantly better for geometry-sensitive properties (band gap, elastic moduli). Developed by the JARVIS team as the alignn package, built on JARVIS-Tools. Best-in-class on most JARVIS benchmarks at the time of publication.

    MEGNet
    MatErials Graph Network

    Updates atom, bond and global-state vectors; the global state can carry conditions such as temperature or pressure. Handles molecules and crystals. TensorFlow/Keras implementation, trained on Materials Project (crystals) and QM9 (molecules).

    SchNet
    Continuous-filter CNN for molecules

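    Expands interatomic distances in Gaussian radial basis functions and feeds them to a filter-generating network. The result is continuous-filter convolutions that are rotationally invariant. Excellent for molecular dynamics and atomistic potentials; later equivariant potentials such as NequIP and MACE build on this line of work.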

    6. From CIF File to Graph — A Practical Walkthrough

    Let us build a crystal graph for SrTiO₃ perovskite starting from a CIF file, using the JARVIS-Tools library (on which the ALIGNN package is built).

    from jarvis.core.atoms import Atoms
    from jarvis.core.graphs import Graph

    # Load structure from CIF (POSCAR and other formats have their own loaders)
    atoms = Atoms.from_cif('SrTiO3.cif')
    print(f"Atoms: {atoms.elements}") # e.g. ['Sr', 'Ti', 'O', 'O', 'O']

    # Build the crystal (atom) graph only; with compute_line_graph=True
    # the call returns an (atom graph, line graph) pair instead
    g = Graph.atom_dgl_multigraph(
        atoms,
        cutoff=8.0,              # neighbour cutoff in Å
        max_neighbors=12,        # k nearest neighbours per atom
        atom_features='cgcnn',   # 92-dim CGCNN atom features
        compute_line_graph=False,
    )

    print(f"Nodes: {g.num_nodes()}") # 5 atoms
    print(f"Edges: {g.num_edges()}") # ~60 bonds (periodic images)
    print(f"Node feat dim: {g.ndata['atom_features'].shape[1]}") # 92
    print(f"Edge feat dim: {g.edata['r'].shape[1]}") # 3 (displacement vector)

    Running ALIGNN predictions with pre-trained weights

    import torch
    from alignn.graphs import Graph
    from alignn.pretrained import get_figshare_model
    from jarvis.core.atoms import Atoms
    from jarvis.db.figshare import data as jdata

    # Download (on first use) the pre-trained ALIGNN model for formation energy
    model = get_figshare_model(model_name='jv_formation_energy_peratom_alignn')
    device = next(model.parameters()).device

    def predict(atoms):
        """Build the atom graph and its line graph, then run one forward pass."""
        g, lg = Graph.atom_dgl_multigraph(atoms, cutoff=8.0, max_neighbors=12)
        with torch.no_grad():
            return model([g.to(device), lg.to(device)]).item()

    # Predict for SrTiO3
    atoms = Atoms.from_cif('SrTiO3.cif')
    print(f"Predicted Eform: {predict(atoms):.4f} eV/atom") # ≈ −3.52 eV/atom

    # Run on a list of structures
    dft_3d = jdata('dft_3d')[:100] # first 100 JARVIS-DFT entries
    predictions = [predict(Atoms.from_dict(d['atoms'])) for d in dft_3d]

    7. Training Your Own ALIGNN Model — Formation Energy from JARVIS-DFT

    # Install: pip install alignn jarvis-tools dgl torch

    from alignn.config import TrainingConfig
    from alignn.train import train_dgl

    config = TrainingConfig(
        dataset='dft_3d',        # JARVIS-DFT 3D dataset
        target='formation_energy_peratom',
        random_seed=123,
        epochs=100,
        n_train=44578,
        n_val=5572,
        n_test=5572,
        batch_size=64,
        learning_rate=1e-3,
        criterion='mse',
        optimizer='adamw',
        scheduler='onecycle',
        model={                  # architecture settings live in the nested model config
            'name': 'alignn',
            'alignn_layers': 4,  # ALIGNN layers (line graph + atom graph updates)
            'gcn_layers': 4,     # further atom-graph-only layers
            'atom_input_features': 92,
            'edge_input_features': 80,     # 80-dim GBF for bond distances
            'triplet_input_features': 40,  # 40-dim GBF for bond angles
            'hidden_features': 256,
            'output_features': 1,
        },
    )

    train_dgl(config) # trains the model and writes outputs to the output directory
    ⚠️
    Data requirements for GNNs

    Training ALIGNN from scratch on JARVIS-DFT (~55,000 structures) requires a GPU and ~12 h of training. For smaller datasets (< 2,000 structures) fine-tune a pretrained model: freeze the graph convolution layers, retrain only the output head at a low learning rate (1e-4). This is the materials science equivalent of transfer learning from ImageNet.
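To make the freeze-and-retrain idea concrete without a GPU, here is a NumPy sketch of head-only fine-tuning. Everything in it is a hypothetical stand-in:

- a random "pretrained" backbone;
- synthetic data rather than DFT values;
- plain gradient descent at lr = 0.01 instead of an AdamW run at 1e-4.

Because the backbone is frozen, its embeddings are computed once, and only the new output head receives gradient updates. In PyTorch the equivalent is setting `requires_grad = False` on the convolution-layer parameters and passing only the head's parameters to the optimiser.

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical "pretrained" backbone: 10 input features -> 16-dim structure embedding
W_backbone = rng.normal(size=(10, 16))
X = rng.normal(size=(200, 10))                        # synthetic inputs, not real structures
y = np.tanh(X @ W_backbone) @ rng.normal(size=16) + 0.05 * rng.normal(size=200)

def embed(X):
    """Frozen backbone: its weights are never updated below."""
    return np.tanh(X @ W_backbone)

W_before = W_backbone.copy()
w_head, b_head = np.zeros(16), 0.0                    # new output head, trained from scratch
lr = 1e-2                                             # plain gradient-descent step size

Z = embed(X)                                          # frozen backbone -> embeddings computed once
for _ in range(2000):
    err = Z @ w_head + b_head - y
    w_head -= lr * 2 * Z.T @ err / len(y)             # MSE gradient w.r.t. head weights only
    b_head -= lr * 2 * err.mean()

mse = np.mean((Z @ w_head + b_head - y) ** 2)
print(np.array_equal(W_backbone, W_before), round(float(mse), 4))
```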

    🔗
    App 12 — GNN Message Passing Visualiser
    Build a crystal graph by placing atoms, set bond distances and angles, watch message passing propagate in real time, compare CGCNN vs ALIGNN feature evolution, and predict formation energy from a simulated 3-atom graph.

    8. GNN vs CNN vs MLP — When to Use Each

    Property | MLP | CNN | GNN
    Input type | Fixed-length feature vector | 2D/3D image (regular grid) | Graph (variable atoms/bonds)
    Spatial awareness | ✗ None | ✓ 2D locality only | ✓ Full topology
    Handles periodic boundary conditions | – | ✗ (with tricks) | ✓ (edges to periodic images)
    Rotation invariance | – | Approx. with augmentation | ✓ (invariant representations)
    Best materials use | Composition-only models, fast screening | Phase ID from diffraction/SEM images | Structure → property (Eform, Eg, moduli)
    Data need | ~100–1,000 | ~1,000–10,000 images | ~5,000–100,000 structures
    Key libraries | scikit-learn, PyTorch | torchvision, PyTorch | PyTorch Geometric, DGL, JARVIS

    Quick Check

    1. In CGCNN, what does an edge in the crystal graph represent, and what features does it carry?

    • A. A unit cell vector; features are the lattice parameters a, b, c
    • B. A bond between two atoms within the cutoff radius; features are a Gaussian basis function expansion of the bond distance (~41 values)
    • C. A symmetry operation (rotation/reflection) of the space group
    • D. A shared electron pair; features are the bond order and bond polarity

    2. ALIGNN improves over CGCNN mainly because it explicitly encodes which structural information that CGCNN misses?

    • A. Atomic mass and nuclear charge
    • B. The full 3D coordinates of every atom in Cartesian space
    • C. Bond angles (∠i–j–k), captured through a second "line graph" where bonds become nodes and shared-atom pairs become edges
    • D. The magnetic moment on each transition metal atom

    3. You have 800 DFT-computed formation energies for a family of novel perovskites not in any public database. The best GNN strategy is:

    • A. Train a full ALIGNN from random weights on your 800 structures for 200 epochs
    • B. Use a random forest on Magpie composition features — GNNs are never worth it below 10,000 samples
    • C. Download a pretrained ALIGNN (e.g. from JARVIS), freeze the convolution layers, replace the output head, and fine-tune for 20–50 epochs at lr ≈ 1e-4 — transfer learning dramatically reduces data requirements
    • D. Convert structures to 2D diffraction images and use the CNN from Post 11 instead