AI & Machine Learning for Materials Sciences

Post 11: Convolutional Neural Networks — Reading Crystal Images

How CNNs learn to detect local patterns in images — and why the same idea, applied to 2D electron density maps and diffraction patterns, powers modern materials property prediction from raw visual data.

Module 3

Key Operation: 2D convolution + pooling
What CNNs Learn: Edges → textures → motifs → properties
Materials Input: Diffraction patterns, electron density maps
Famous Architecture: ResNet, VGG (adapted for materials)

Posts 9 and 10 showed you how a fully connected network learns by adjusting every weight in response to a global error signal. But when the input is an image — say, an X-ray diffraction pattern or a 2D slice of an electron density map from a DFT calculation — a fully connected network faces two crippling problems: it has an enormous number of parameters, and it completely ignores the spatial structure of the image.

Convolutional Neural Networks (CNNs) solve both problems at once. Instead of connecting every pixel to every neuron, a CNN slides small, learnable filters across the image, detecting the same local pattern (an edge, a bright ring, a lattice fringe) wherever it appears. This gives CNNs translation equivariance — a peak in the diffraction pattern is recognised regardless of whether it sits at the top-left or bottom-right of the image.
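Translation equivariance can be checked numerically. The sketch below is an illustration rather than part of this post's model code: it assumes PyTorch, uses circular padding so that the equality holds exactly at the image border, and the 16×16 random image and the (3, 5) pixel shift are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One 3×3 filter; circular padding wraps the image so border effects vanish.
conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, padding_mode="circular", bias=False)

image = torch.rand(1, 1, 16, 16)   # (batch, channels, height, width)
shift = (3, 5)                     # arbitrary shift in (rows, columns)

with torch.no_grad():
    shifted_then_filtered = conv(torch.roll(image, shifts=shift, dims=(2, 3)))
    filtered_then_shifted = torch.roll(conv(image), shifts=shift, dims=(2, 3))

# Equivariance: shifting the input shifts the feature map by the same amount.
is_equivariant = torch.allclose(shifted_then_filtered, filtered_then_shifted, atol=1e-6)
print(is_equivariant)
```

With ordinary zero padding the same relation holds everywhere except close to the image border.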

🔬
Why images for materials?

Many experimental and computational characterisation techniques produce 2D images as their primary output: scanning electron microscopy (SEM), transmission electron microscopy (TEM), selected-area electron diffraction (SAED), powder X-ray diffractograms, and 2D slices of DFT charge density files (CHGCAR from VASP). A CNN trained on these images can predict crystal symmetry, phase, or electronic properties without requiring a manually engineered feature vector.

1. The Convolution Operation — Sliding a Learnable Filter

A convolutional layer applies a small matrix called a kernel (or filter) across every position of the input image. At each position, it computes the element-wise product of the kernel with the overlapping image patch and sums the results. This single number becomes one pixel in the feature map (or activation map) output.

🔲 2D Convolution Formula
(I ★ K)[i, j] = Σₘ Σₙ I[i+m, j+n] · K[m, n] + b

I = input image    K = kernel (e.g. 3×3 matrix)
b = bias (scalar, learned like any other parameter)
[i, j] = position of the output pixel in the feature map

The kernel K contains the learnable parameters — backpropagation
optimises them exactly as it does for fully connected weights.

A critical insight: the same kernel weights are reused at every position. A 3×3 kernel applied to a 64×64 image has only 9 weights (plus one bias), whereas a single fully connected neuron looking at the same image would need 64×64 = 4096 weights. This weight sharing is what makes CNNs so parameter-efficient and helps limit overfitting when data is scarce.

Example: a 3×3 edge-detection kernel on a diffraction image

Consider the following kernel applied to a region of an X-ray diffraction pattern. The kernel is designed to highlight vertical intensity transitions — which correspond to sharp boundaries between diffraction rings and background:

Kernel · Patch → Output Pixel

Input patch (intensity):
[ 12 15 80 ]
[ 10 13 85 ]
[ 11 14 82 ]

Vertical edge kernel:
[ −1 0 +1 ]
[ −1 0 +1 ]
[ −1 0 +1 ]

Output pixel: (80 − 12) + (85 − 10) + (82 − 11) = 68 + 75 + 71 = +214
Strong vertical edge detected
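The arithmetic in this example can be reproduced with a few lines of NumPy. This is a minimal sketch of the formula from this section with no padding and stride 1; the function name `correlate2d_valid` is a hypothetical helper, not a library call.

```python
import numpy as np

def correlate2d_valid(image, kernel, bias=0.0):
    """Slide `kernel` over `image` (no padding, stride 1) and return the feature map."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.empty((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            patch = image[i:i + kh, j:j + kw]            # region under the kernel
            out[i, j] = np.sum(patch * kernel) + bias    # element-wise product, then sum
    return out

patch = np.array([[12, 15, 80],
                  [10, 13, 85],
                  [11, 14, 82]], dtype=float)
vertical_edge = np.array([[-1, 0, 1],
                          [-1, 0, 1],
                          [-1, 0, 1]], dtype=float)

# A 3×3 patch and a 3×3 kernel overlap in exactly one position: one output pixel.
output_pixel = correlate2d_valid(patch, vertical_edge)[0, 0]
print(output_pixel)
```

The same nine kernel values are reused at every position, so applying the function to a full 64×64 image needs no additional parameters.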

In practice, a CNN never hand-designs these kernels. Backpropagation learns them automatically from the training data. A network trained on diffraction patterns of cubic, hexagonal, and orthorhombic crystals will learn kernels that respond to the symmetry-specific ring patterns that distinguish these phases.

2. Multiple Filters — Each One Learns a Different Pattern

A single convolutional layer typically applies many kernels simultaneously — one for each type of pattern the network should look for. Each kernel produces its own feature map. If you use 32 kernels, you get 32 feature maps stacked along a depth dimension. The next layer then applies kernels to this 3D volume, learning combinations of the patterns detected below.
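The stacking of feature maps is easiest to see in tensor shapes. The following is a small sketch assuming PyTorch; the 64 filters in the second layer are an illustrative choice, and the input is random data standing in for a 64×64 grayscale image.

```python
import torch
import torch.nn as nn

# 32 kernels on a 1-channel image, then 64 kernels on the resulting 32-channel volume.
layer1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
layer2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)

image = torch.rand(1, 1, 64, 64)   # (batch, channels, height, width)
maps1 = layer1(image)              # one feature map per kernel
maps2 = layer2(maps1)

print(maps1.shape)           # torch.Size([1, 32, 64, 64])
print(layer2.weight.shape)   # torch.Size([64, 32, 3, 3]): each kernel spans all 32 input maps
print(maps2.shape)           # torch.Size([1, 64, 64, 64])
```

Each second-layer kernel is a 3×3×32 block, which is how it can combine the patterns detected by all 32 first-layer filters.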

🔲 Layer 1 Filters
Low-level: edges, intensity gradients, sharp transitions between bright diffraction spots and dark background.
🔶 Layer 2 Filters
Mid-level: ring segments, lattice fringe patterns, repeating bright spot arrays — combinations of Layer 1 edges.
⭐ Layer 3+ Filters
High-level: complete diffraction ring symmetry, Bragg peak arrangements — features that encode crystal class information.
💡
The hierarchy of representations

This automatic hierarchy — edges → textures → motifs → semantic features — mirrors how a trained crystallographer reads a diffraction pattern. The network learns to see "that ring spacing means FCC, that spot arrangement means BCC" without being told. The same principle makes CNNs powerful for SEM images (grain boundaries → grain morphology → phase identification) and TEM images (atomic columns → unit cell → space group).

3. Padding, Stride, and Output Size

Two hyperparameters control the geometry of each convolutional layer:

Hyperparameter | What it does | Typical choice | Effect on output size
Padding (p) | Adds zeros around the image border so the kernel can be applied to edge pixels | p = 1 for a 3×3 kernel ("same" padding) | Output = input size (no shrinkage)
Stride (s) | How many pixels the kernel moves between applications | s = 1 (dense) or s = 2 (downsampling) | Output ≈ input / s
📐 Output Size Formula
Output size = ⌊ (Input − Kernel + 2·Padding) / Stride ⌋ + 1

Example: 64×64 image, 3×3 kernel, padding=1, stride=1
→ ⌊ (64 − 3 + 2) / 1 ⌋ + 1 = 64 ✓ (size preserved)

Example: 64×64 image, 3×3 kernel, padding=0, stride=2
→ ⌊ (64 − 3 + 0) / 2 ⌋ + 1 = 31 (downsampled)
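The output size formula translates directly into a helper function. This is a plain-Python sketch; the name `conv_output_size` is hypothetical, and it assumes a square input and a square kernel.

```python
def conv_output_size(input_size, kernel, padding=0, stride=1):
    """Output size = floor((input - kernel + 2 * padding) / stride) + 1."""
    return (input_size - kernel + 2 * padding) // stride + 1

print(conv_output_size(64, 3, padding=1, stride=1))  # 64: size preserved
print(conv_output_size(64, 3, padding=0, stride=2))  # 31: downsampled
print(conv_output_size(64, 2, padding=0, stride=2))  # 32: a 2×2 window with stride 2 halves the size
```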

4. Pooling — Compressing the Feature Maps

After each convolutional + activation layer, a pooling layer reduces the spatial size of the feature maps. The most common type is max pooling: slide a 2×2 window with stride 2 and keep only the maximum activation in each window. This halves the height and width, reducing computation and building in a degree of spatial invariance — a diffraction ring detected slightly off-centre still activates the same pooled feature.

⬇️ 2×2 Max Pooling (stride 2)
Input feature map (4×4):
[ 1 3 | 2 4 ]
[ 5 6 | 1 2 ]
[ 7 1 | 8 3 ]
[ 2 4 | 5 6 ]

Output (2×2), max of each 2×2 block:
[ 6 4 ]
[ 7 8 ]

No parameters — pooling is a fixed operation, not a learned one.
Backpropagation routes the gradient only to the winning (max) pixel.
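The 4×4 example can be recomputed with a short NumPy sketch. The helper name `max_pool_2x2` is hypothetical, and the reshape trick assumes the height and width are both even.

```python
import numpy as np

def max_pool_2x2(feature_map):
    """2×2 max pooling with stride 2 on a 2D array with even height and width."""
    h, w = feature_map.shape
    # Split into (block_row, row_in_block, block_col, col_in_block) ...
    blocks = feature_map.reshape(h // 2, 2, w // 2, 2)
    # ... and take the maximum inside every 2×2 block.
    return blocks.max(axis=(1, 3))

feature_map = np.array([[1, 3, 2, 4],
                        [5, 6, 1, 2],
                        [7, 1, 8, 3],
                        [2, 4, 5, 6]])

pooled = max_pool_2x2(feature_map)
print(pooled)
```

In PyTorch the equivalent layer is `nn.MaxPool2d(2, 2)`, as used in the implementation section of this post.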
🔬
Global Average Pooling — the modern alternative

Many modern architectures (ResNet, EfficientNet) replace the final max pooling + fully connected layers with a single Global Average Pooling (GAP) layer that averages each feature map to a single number. For a diffraction pattern, this means "how much does each filter pattern appear anywhere in the image?" GAP dramatically reduces parameters and overfitting — important when training on small materials datasets of a few hundred diffraction images.
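To make the parameter saving concrete, the sketch below compares a dense layer fed by flattened feature maps with one fed by GAP output. It assumes PyTorch and a 128×16×16 feature volume with a 64-unit dense layer; both sizes are illustrative choices, and `n_params` is a hypothetical helper.

```python
import torch
import torch.nn as nn

feature_maps = torch.rand(8, 128, 16, 16)   # (batch, channels, height, width)

# GAP: one average per feature map, i.e. a mean over height and width.
gap = nn.AdaptiveAvgPool2d(1)
pooled = gap(feature_maps).flatten(1)       # shape (8, 128)
same_as_mean = torch.allclose(pooled, feature_maps.mean(dim=(2, 3)), atol=1e-6)

def n_params(module):
    return sum(p.numel() for p in module.parameters())

head_after_flatten = nn.Linear(128 * 16 * 16, 64)   # dense layer on the flattened volume
head_after_gap = nn.Linear(128, 64)                 # dense layer on GAP output

print(same_as_mean)
print(n_params(head_after_flatten))   # 2097216
print(n_params(head_after_gap))       # 8256
```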

5. A Complete CNN Architecture for Crystal Phase Classification

Let's trace a realistic CNN designed to classify 2D powder diffractograms into three crystal systems: cubic, hexagonal, or orthorhombic. The input is a 64×64 grayscale image (intensity as a function of 2θ angle, rendered as a 2D pattern).

CNN Pipeline: 64×64 Diffraction Image → Crystal System

Input: 64×64×1
→ Conv1 (3×3, ReLU): 64×64×32
→ Pool1 (2×2 max): 32×32×32
→ Conv2 (3×3, ReLU): 32×32×64
→ Pool2 (2×2 max): 16×16×64
→ Conv3 (3×3, ReLU): 16×16×128
→ GAP (global avg): 128
→ Dense (ReLU): 64
→ Output (Softmax): 3

Parameter count — vs fully connected

Layer | Output shape | Parameters | Notes
Conv1 (32 filters, 3×3) | 64×64×32 | 32 × (3×3×1 + 1) = 320 | 9 weights + 1 bias per filter
Pool1 (2×2 max) | 32×32×32 | 0 | No learnable params
Conv2 (64 filters, 3×3) | 32×32×64 | 64 × (3×3×32 + 1) = 18,496 | Each filter sees all 32 input channels
Pool2 (2×2 max) | 16×16×64 | 0 |
Conv3 (128 filters, 3×3) | 16×16×128 | 128 × (3×3×64 + 1) = 73,856 |
Global Avg Pool | 128 | 0 | Avoids large FC layer
Dense (64) | 64 | 128×64 + 64 = 8,256 |
Output (3) | 3 | 64×3 + 3 = 195 | Softmax probabilities
Total CNN | | ~101,000 |
Fully connected equivalent | | 64×64 × 64 = 262,144 (first layer alone) | ~2.6× more, just for one layer
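The totals in this table can be recomputed with two small formulas. This is a plain-Python sketch; `conv_params` and `dense_params` are hypothetical helper names, and the fully connected comparison counts weights only, as the table does.

```python
def conv_params(n_filters, kernel, in_channels):
    """Each filter has kernel × kernel × in_channels weights plus one bias."""
    return n_filters * (kernel * kernel * in_channels + 1)

def dense_params(n_in, n_out):
    """One weight per input-output pair plus one bias per output."""
    return n_in * n_out + n_out

layers = {
    "Conv1": conv_params(32, 3, 1),
    "Conv2": conv_params(64, 3, 32),
    "Conv3": conv_params(128, 3, 64),
    "Dense": dense_params(128, 64),
    "Output": dense_params(64, 3),
}
total = sum(layers.values())

# First layer of a fully connected network on the same 64×64 input, 64 hidden units.
fc_first_layer = 64 * 64 * 64

for name, count in layers.items():
    print(f"{name}: {count:,}")
print(f"Total CNN: {total:,}")
print(f"FC first layer: {fc_first_layer:,} ({fc_first_layer / total:.1f}x the whole CNN)")
```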

6. Training — Backpropagation Through Convolutional Layers

Training a CNN uses exactly the same backpropagation algorithm from Post 10, extended to convolutional layers. The key insight is that the gradient of the loss with respect to a kernel weight k is the sum of contributions from every position where that kernel was applied:

🔄 Gradient of the loss w.r.t. a kernel weight k[m,n]
∂L/∂K[m,n] = Σᵢ Σⱼ δ[i,j] · I[i+m, j+n]

δ[i,j] = gradient flowing back to position (i,j) of the feature map
I[i+m, j+n] = input patch value at that position

This is itself a correlation — the same operation as the forward pass.
PyTorch's autograd computes this automatically via loss.backward().
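The gradient formula can be compared against autograd on a toy example. In this sketch (assuming PyTorch) the 8×8 image, the 3×3 kernel and the random `delta` are illustrative stand-ins, and the loss is constructed so that the gradient arriving at the feature map is exactly `delta`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
image = torch.rand(1, 1, 8, 8)
kernel = torch.rand(1, 1, 3, 3, requires_grad=True)
delta = torch.rand(1, 1, 6, 6)            # stand-in for the upstream gradient δ[i, j]

feature_map = F.conv2d(image, kernel)     # cross-correlation, output is 6×6
loss = (feature_map * delta).sum()        # chosen so that ∂L/∂feature_map = δ
loss.backward()

# Manual gradient: ∂L/∂K[m, n] = Σᵢ Σⱼ δ[i, j] · I[i+m, j+n]
manual_grad = torch.zeros(3, 3)
for m in range(3):
    for n in range(3):
        manual_grad[m, n] = (delta[0, 0] * image[0, 0, m:m + 6, n:n + 6]).sum()

# The same sum written as a correlation of the input with δ.
as_correlation = F.conv2d(image, delta)[0, 0]

matches_autograd = torch.allclose(kernel.grad[0, 0], manual_grad, atol=1e-4)
matches_correlation = torch.allclose(as_correlation, manual_grad, atol=1e-4)
print(matches_autograd, matches_correlation)
```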
⚙️
What makes CNN training different from Post 10?

In a fully connected network, each weight gets one gradient contribution per training sample. In a CNN, each kernel weight gets as many contributions as the number of positions it was applied to (e.g. 64×64 = 4096 for a 64×64 image). These are summed to give the total gradient. This is why convolutional layers are computationally expensive on large images, and why GPU acceleration (via CUDA) is standard practice.

7. Residual Connections — Solving the Vanishing Gradient Problem for Deep CNNs

Post 10 introduced the dying ReLU / vanishing gradient problem for deep networks. When CNNs are stacked many layers deep (20, 50, 100 layers), gradients can shrink to near-zero before reaching the early layers, making training very slow or impossible. The residual connection (skip connection), introduced in ResNet (He et al., 2016), is the standard fix.

🔗 Residual Block Formula
Output = F(x) + x

x = the input to the block
F(x) = Conv → BN → ReLU → Conv → BN (the residual to learn)

Gradient path 1: ∂L/∂x through F(x) (may vanish)
Gradient path 2: ∂L/∂x through the skip → always = ∂L/∂output (never vanishes)

The skip connection provides a "gradient superhighway" to all early layers.
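A residual block following this formula can be sketched in PyTorch as below. The class name `ResidualBlock` and the channel count are illustrative assumptions. The original ResNet also applies a ReLU after the addition, which is left out here to match the formula as written. Setting the last BatchNorm scale to zero forces F(x) = 0, which isolates the skip path.

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Output = F(x) + x, with F = Conv → BN → ReLU → Conv → BN."""

    def __init__(self, channels):
        super().__init__()
        self.residual = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        return self.residual(x) + x   # skip connection: add the input back

block = ResidualBlock(64)
# Zero the last BatchNorm scale so that F(x) = 0 and its gradient path is dead.
nn.init.zeros_(block.residual[-1].weight)

x = torch.rand(2, 64, 16, 16, requires_grad=True)
out = block(x)
out.sum().backward()

output_equals_input = torch.allclose(out, x)
gradient_survives = torch.allclose(x.grad, torch.ones_like(x))
print(output_equals_input, gradient_survives)
```

Even with the residual branch contributing nothing, the gradient reaching `x` equals the gradient at the output, which is the behaviour described for gradient path 2.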
🔬
ResNet applied to electron microscopy images

Park et al. (2020) and subsequent work from several groups fine-tuned ResNet-50 on TEM images of perovskite oxides to classify oxygen vacancy ordering patterns, a task that previously required hours of manual expert analysis per image. A ResNet pretrained on ImageNet was adapted to the task with only a few hundred TEM images using transfer learning: keep the early convolutional layers (which already detect edges and textures) and retrain only the later layers and the classifier head on materials images. This is the recommended approach when your materials image dataset is small.

8. Transfer Learning — Standing on ImageNet's Shoulders

Training a CNN from scratch requires thousands of labelled images. Most materials research groups cannot produce this volume of annotated data. The solution is transfer learning: start from a network already trained on a large dataset (ImageNet, with 1.2 million images), and fine-tune it on your smaller materials dataset.

🧊

Frozen layers

Freeze all convolutional layers — keep their ImageNet-learned filters intact. Only train the final classification head (dense + softmax). Works when materials images are structurally similar to natural images.

🌡️

Fine-tuning

Unfreeze the last 1–2 conv blocks and train them at a very low learning rate (1×10⁻⁵). Allows the network to adapt high-level filters to materials-specific patterns. Best for SEM/TEM images.

🔄

Full retraining

Retrain all layers from pretrained weights at a moderate learning rate. Use only when you have ≥ 5000 labelled materials images and the domain is very different from natural images (e.g. X-ray diffractograms).

9. PyTorch Implementation — CNN for Diffraction Image Classification

Building the architecture from scratch

import torch, torch.nn as nn

class DiffractionCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # ── Convolutional backbone ─────────────────────────
        self.features = nn.Sequential(
            # Block 1: 64×64×1 → 64×64×32 → 32×32×32
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2), # → 32×32×32

            # Block 2: 32×32×32 → 32×32×64 → 16×16×64
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2), # → 16×16×64

            # Block 3: 16×16×64 → 16×16×128
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.gap = nn.AdaptiveAvgPool2d(1) # → 128×1×1
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(64, 3) # cubic / hexagonal / orthorhombic
        )

    def forward(self, x):
        x = self.features(x) # convolutional backbone
        x = self.gap(x) # global average pooling
        return self.classifier(x)

model = DiffractionCNN()
print(sum(p.numel() for p in model.parameters()), 'parameters') # 101,571 = 101,123 from the table + 448 BatchNorm parameters

Training loop with data augmentation

from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# Augmentation — rotate diffraction patterns (any rotation is physically valid)
train_tf = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((64, 64)),
    transforms.RandomRotation(180), # diffraction has rotational symmetry
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

train_ds = datasets.ImageFolder('data/diffraction/train', transform=train_tf)
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)

opt = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=50)

for epoch in range(50):
    model.train()
    for imgs, labels in train_dl:
        opt.zero_grad()
        loss = loss_fn(model(imgs), labels)
        loss.backward()
        opt.step()
    scheduler.step()

# Inspect what Conv1 learned (first 4 filters)
filters = model.features[0].weight.data # shape: (32, 1, 3, 3)
print(filters[:4].squeeze()) # print first 4 kernels

Transfer learning from ResNet-18 — recommended for small datasets

from torchvision.models import resnet18, ResNet18_Weights

# Load pretrained weights
backbone = resnet18(weights=ResNet18_Weights.DEFAULT)

# Adapt the first conv for grayscale input.
# Note: the new layer is randomly initialised, so the pretrained RGB conv1 filters are discarded.
backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Freeze every parameter for now (the layers to train are unfrozen below)
for param in backbone.parameters():
    param.requires_grad = False

# Replace classifier head: 512 → 3 classes (a new layer is trainable by default)
backbone.fc = nn.Linear(512, 3)

# Unfreeze the new conv1 so it can learn grayscale filters
for param in backbone.conv1.parameters():
    param.requires_grad = True

trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable:,}") # 4,675 (conv1: 3,136 + fc: 1,539) vs ~11M total
🔲
App 11 — CNN Filter Visualiser
Watch convolution kernels slide across a synthetic diffraction pattern in real time. See how filter responses build into feature maps, adjust kernel weights manually, compare pooling strategies, and classify simulated cubic vs hexagonal vs orthorhombic patterns.

10. CNN vs Fully Connected — When to Use Each

Property | Fully Connected (MLP) | CNN
Input type | Flat feature vector (e.g. 3 elemental descriptors) | 2D image (diffraction, electron density, SEM)
Spatial structure | Ignored: all features treated equally | Exploited: the same kernels are applied to every neighbourhood of pixels
Parameters | Grows as input_size × hidden_size | Grows as n_filters × kernel² × input_channels (much smaller)
Translation invariance | ✗ | ✓ (after pooling)
Materials use case | Band gap from composition vectors, Magpie features | Phase ID from diffraction images, defect detection in TEM
Data requirement | Can work with ~100–1000 samples | Needs ~1000+ images (or transfer learning for fewer)

Quick Check

1. A convolutional layer uses 64 filters of size 3×3 on an input with 32 channels. How many learnable parameters does this layer have (including biases)?

  • A. 64 × 3 × 3 = 576
  • B. 64 × 3 × 3 × 32 = 18,432
  • C. 64 × (3 × 3 × 32 + 1) = 18,496 (including one bias per filter)
  • D. 3 × 3 × 32 = 288 (one kernel shared by all filters)

2. A CNN is trained only on upright diffraction images and then tested on the same patterns rotated by 45°. What should you expect?

  • A. No change: the softmax output automatically normalises for rotation
  • B. No change: max pooling makes the network fully rotation-invariant
  • C. No change: batch normalisation removes the rotational information before the classifier
  • D. Accuracy is likely to drop: convolution provides translation equivariance, not rotation equivariance, so robustness to rotation has to be learned, for example through random-rotation augmentation during training

3. You have 200 labelled SEM images of three aluminium alloy microstructures. Which CNN training strategy is most appropriate?

  • A. Train a 20-layer ResNet from random weights on these 200 images
  • B. Use only Global Average Pooling with no convolutional layers
  • C. Freeze all layers of a pretrained ResNet-18 except the final classifier, and fine-tune on the 200 images at a low learning rate — transfer learning is essential at this dataset size
  • D. Convert the images to feature vectors and use a fully connected network instead