Introduction to Graph Neural Networks for Spatial Transcriptomics
A comprehensive guide to understanding and applying GNNs to spatial transcriptomics data, from foundational concepts to cutting-edge research implementations
Table of Contents
- Introduction
- From CNNs to GNNs
- Core GNN Architectures
- Graph Construction Strategies
- State-of-the-Art Methods in Spatial Transcriptomics
- Implementation Guide
- Practical Considerations
Introduction
Spatial transcriptomics (ST) captures gene expression while preserving the physical location of cells in tissue. This spatial context creates a natural graph structure where:
- Nodes = spots/cells with gene expression vectors
- Edges = spatial proximity or molecular similarity
- Task = identify tissue domains, cell types, or spatial patterns
Graph Neural Networks provide a powerful framework to leverage this structure, enabling message-passing that respects both spatial topology and molecular relationships.
From CNNs to GNNs
Why CNNs Can't Handle Irregular Grids
Convolutional Neural Networks excel at structured data (images) because:
- Local connectivity: Each pixel has a fixed neighborhood (e.g., 3×3)
- Weight sharing: The same filter slides across all positions, making the model parameter-efficient
- Translation invariance: Patterns learned anywhere apply everywhere
The convolution operation at pixel $(i, j)$ with a 3×3 kernel is:

$$y_{i,j} = \sigma\left(\sum_{m=-1}^{1} \sum_{n=-1}^{1} w_{m,n} \, x_{i+m,\, j+n}\right)$$
The problem for spatial transcriptomics: Spots are irregularly spaced and often have varying numbers of neighbors. We can't use a fixed 3×3 kernel.
Message Passing: The Key Insight
GNNs generalize convolution to arbitrary graph structures through message passing:

$$h_i^{(l+1)} = \sigma\left(W^{(l)} \sum_{j \in \mathcal{N}(i)} h_j^{(l)}\right)$$

Where:
- $h_i^{(l)}$ = feature vector of node $i$ at layer $l$
- $\mathcal{N}(i)$ = neighbors of node $i$
- $W^{(l)}$ = learnable weight matrix
- $\sigma$ = activation function (ReLU, etc.)
This is conceptually similar to CNN convolution, but:
- Neighbors are defined by edges, not spatial proximity on a grid
- Each node can have a variable number of neighbors
- Weights are shared across all nodes, just like CNN filters
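To make the update concrete, here is a toy example of one message-passing step on a three-node path graph, written in plain PyTorch (names and values are illustrative):

```python
import torch

# Toy graph: node 0 -- node 1 -- node 2 (undirected path)
x = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0]])             # [3 nodes, 2 features]
edge_index = torch.tensor([[0, 1, 1, 2],   # source nodes
                           [1, 0, 2, 1]])  # destination nodes

W = torch.nn.Linear(2, 2, bias=False)      # shared weight matrix W^{(l)}

# One message-passing step: each node sums its neighbors' features,
# applies the shared linear map, then a nonlinearity.
src, dst = edge_index
agg = torch.zeros_like(x).index_add_(0, dst, x[src])  # sum over N(i)
h_next = torch.relu(W(agg))                           # h^{(l+1)}
print(h_next.shape)  # torch.Size([3, 2])
```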
Core GNN Architectures
1. Graph Convolutional Network (GCN)
Paper: Kipf & Welling (2017)
Core Idea: Normalize neighbor aggregation by degree to prevent scale issues.
Mathematical Formulation
Given:
- $X \in \mathbb{R}^{N \times F}$: node features ($N$ spots, $F$ genes)
- $A \in \mathbb{R}^{N \times N}$: adjacency matrix
- $D$: degree matrix (diagonal, where $D_{ii} = \sum_j A_{ij}$)

Add self-loops and normalize:

$$\tilde{A} = A + I, \qquad \tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$$

The GCN layer performs symmetric normalization:

$$H^{(l+1)} = \sigma\left(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)}\right)$$
Intuition: The symmetric normalization creates a normalized adjacency that averages features fairly across neighbors, preventing high-degree nodes from dominating.
PyTorch Implementation (Matrix Form)
import torch
import torch.nn as nn
import torch.nn.functional as F
def normalize_adj(edge_index, num_nodes):
"""
Build symmetric normalized adjacency: D^{-1/2} A D^{-1/2}
Returns sparse tensor indices and values.
"""
i, j = edge_index
# Add self-loops
self_loops = torch.arange(num_nodes, device=edge_index.device)
i = torch.cat([i, self_loops])
j = torch.cat([j, self_loops])
# Compute degree
deg = torch.bincount(j, minlength=num_nodes).float()
deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
# Symmetric normalization weights
vals = deg_inv_sqrt[i] * deg_inv_sqrt[j]
return torch.stack([i, j], dim=0), vals
class GCNLayer(nn.Module):
"""
Graph Convolutional Layer with sparse adjacency multiply.
"""
def __init__(self, in_channels, out_channels, bias=True):
super().__init__()
self.lin = nn.Linear(in_channels, out_channels, bias=bias)
def forward(self, x, edge_index):
N = x.size(0)
idx, vals = normalize_adj(edge_index, N)
A_norm = torch.sparse_coo_tensor(idx, vals, (N, N), device=x.device)
x = self.lin(x) # [N, F_out]
x = torch.sparse.mm(A_norm, x) # Sparse matrix multiply
        return x  # activation is applied by the calling model
# Full GCN Model
class GCN(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
super().__init__()
self.layers = nn.ModuleList()
self.layers.append(GCNLayer(in_channels, hidden_channels))
for _ in range(num_layers - 2):
self.layers.append(GCNLayer(hidden_channels, hidden_channels))
self.layers.append(GCNLayer(hidden_channels, out_channels))
def forward(self, x, edge_index):
for i, layer in enumerate(self.layers[:-1]):
x = layer(x, edge_index)
x = torch.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)
x = self.layers[-1](x, edge_index)
        return x

Strengths:
- Simple, efficient, mathematically principled
- Fast on small-to-medium graphs
- Smooth aggregation good for homogeneous tissue
Weaknesses:
- Fixed (isotropic) weights—all neighbors treated equally
- Transductive (needs full graph at training)
- Over-smoothing with many layers
2. GraphSAGE
Paper: Hamilton et al. (2017)
Core Idea: Sample neighbors and use robust aggregators (mean, LSTM, max-pool) to scale to large graphs and enable inductive learning (generalize to unseen nodes/graphs).
Update Rule

$$h_i^{(l+1)} = \sigma\left(W^{(l)} \cdot \left[\, h_i^{(l)} \,\Vert\, \text{AGG}\left(\{h_j^{(l)} : j \in \mathcal{N}(i)\}\right) \right]\right)$$
Key differences from GCN:
- Concatenate self-feature with aggregated neighbors (preserves identity)
- Sample a fixed number of neighbors (e.g., 10) instead of using all
- Aggregator choices: mean, max-pool, LSTM
PyTorch Implementation
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
class SAGEConv(nn.Module):
"""
GraphSAGE layer with mean aggregation.
"""
def __init__(self, in_channels, out_channels, normalize=True):
super().__init__()
self.lin_self = nn.Linear(in_channels, out_channels)
self.lin_neigh = nn.Linear(in_channels, out_channels)
self.normalize = normalize
def forward(self, x, edge_index, num_neighbors_sample=10):
N = x.size(0)
# Sample neighbors (for scalability)
src, dst = edge_index
# Group by destination
neighbors = {}
for s, d in zip(src.tolist(), dst.tolist()):
neighbors.setdefault(d, []).append(s)
# Sample and aggregate
neigh_features = []
for i in range(N):
            if i in neighbors:
                pool = neighbors[i]
                # Randomly sample up to num_neighbors_sample neighbors
                if len(pool) > num_neighbors_sample:
                    pool = random.sample(pool, num_neighbors_sample)
                sampled = torch.tensor(pool, device=x.device)
neigh_features.append(x[sampled].mean(dim=0))
else:
neigh_features.append(torch.zeros(x.size(1), device=x.device))
neigh_agg = torch.stack(neigh_features) # [N, F]
# Combine self and neighbor
out = self.lin_self(x) + self.lin_neigh(neigh_agg)
if self.normalize:
out = F.normalize(out, p=2, dim=1)
return out
class GraphSAGE(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
super().__init__()
self.layers = nn.ModuleList()
self.layers.append(SAGEConv(in_channels, hidden_channels))
for _ in range(num_layers - 2):
self.layers.append(SAGEConv(hidden_channels, hidden_channels))
self.layers.append(SAGEConv(hidden_channels, out_channels))
def forward(self, x, edge_index):
for layer in self.layers[:-1]:
x = layer(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
x = self.layers[-1](x, edge_index)
        return x

Strengths:
- Scales to massive graphs via neighbor sampling
- Inductive: can generalize to new slides without retraining
- Multiple aggregator options
Weaknesses:
- Sampling introduces variance
- More hyperparameters (sample size per layer)
- Still treats all sampled neighbors equally
3. Graph Attention Network (GAT)
Paper: Veličković et al. (2018)
Core Idea: Learn per-edge attention weights so the model can focus on important neighbors and ignore noisy ones.
Attention Mechanism

For each edge $(i, j)$:

$$e_{ij} = \text{LeakyReLU}\left(\mathbf{a}^\top \left[\, W h_i \,\Vert\, W h_j \,\right]\right)$$

Normalize with softmax over all neighbors:

$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}$$

Aggregate with learned weights:

$$h_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W h_j\right)$$

Multi-head attention (like Transformers): Run $K$ independent attention heads and concatenate:

$$h_i' = \big\Vert_{k=1}^{K} \, \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{(k)} W^{(k)} h_j\right)$$
PyTorch Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
class GATLayer(nn.Module):
"""
Graph Attention Layer (single head).
"""
def __init__(self, in_channels, out_channels, dropout=0.6, alpha=0.2):
super().__init__()
self.W = nn.Linear(in_channels, out_channels, bias=False)
self.a = nn.Parameter(torch.zeros(size=(2 * out_channels, 1)))
self.leakyrelu = nn.LeakyReLU(alpha)
self.dropout = dropout
nn.init.xavier_uniform_(self.a.data, gain=1.414)
def forward(self, x, edge_index):
# Linear transformation
Wh = self.W(x) # [N, out_channels]
N = Wh.size(0)
# Compute attention coefficients
src, dst = edge_index
# Concatenate source and destination features
Wh_src = Wh[src] # [E, out_channels]
Wh_dst = Wh[dst] # [E, out_channels]
concat_features = torch.cat([Wh_src, Wh_dst], dim=1) # [E, 2*out_channels]
# Attention mechanism
e = self.leakyrelu(concat_features @ self.a).squeeze() # [E]
# Normalize per destination node
attention_weights = torch.zeros_like(e)
for i in range(N):
mask = (dst == i)
if mask.sum() > 0:
e_neighbors = e[mask]
alpha = F.softmax(e_neighbors, dim=0)
attention_weights[mask] = alpha
# Apply dropout to attention
attention_weights = F.dropout(attention_weights, self.dropout, training=self.training)
# Aggregate
out = torch.zeros_like(Wh)
for i in range(N):
mask = (dst == i)
if mask.sum() > 0:
out[i] = (attention_weights[mask].unsqueeze(1) * Wh[src[mask]]).sum(dim=0)
return out
class GAT(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels,
num_heads=8, dropout=0.6):
super().__init__()
self.dropout = dropout
# Multi-head first layer
self.attentions = nn.ModuleList([
GATLayer(in_channels, hidden_channels, dropout=dropout)
for _ in range(num_heads)
])
# Single-head output layer
self.out_att = GATLayer(
hidden_channels * num_heads,
out_channels,
dropout=dropout
)
def forward(self, x, edge_index):
# First layer: concatenate multi-head outputs
x = torch.cat([att(x, edge_index) for att in self.attentions], dim=1)
x = F.dropout(x, self.dropout, training=self.training)
x = F.elu(x)
# Output layer
x = self.out_att(x, edge_index)
        return x

Strengths:
- Learns which neighbors are important (anisotropic)
- Better at handling heterogeneous neighborhoods
- Respects sharp boundaries (e.g., tumor-stroma)
Weaknesses:
- More parameters (attention + multi-heads)
- Slower than GCN
- Can be unstable (needs careful tuning of dropout, LR)
Architecture Comparison Table
| Model | Neighbor Weights | Sampling | Inductive? | Best Use Case |
|---|---|---|---|---|
| GCN | Fixed by normalized adjacency | Full graph (no sampling) | No (transductive) | Smooth tissue, clean spatial graphs, small-medium datasets |
| GraphSAGE | Fixed aggregator (mean/max/LSTM) | Neighborhood sampling | Yes | Large graphs, multiple slides, production systems |
| GAT | Learned per-edge attention | Full neighborhood (or can sample) | Often yes | Heterogeneous tissue, sharp boundaries, variable neighbor quality |
Visual Comparison:
GCN: All neighbors weighted equally (by degree normalization)
[Spot] ← (0.25, 0.25, 0.25, 0.25) ← [4 neighbors]
GraphSAGE: Sample subset, aggregate with mean/max
[Spot] ← sample 2 → mean([n1, n2])
GAT: Learn attention per edge
[Spot] ← (α₁=0.5, α₂=0.3, α₃=0.15, α₄=0.05) ← [4 neighbors]
↑ Learned weights
Graph Construction Strategies
Critical insight: Graph topology often matters more than the GNN architecture.
1. Spatial Proximity Graphs
k-Nearest Neighbors (kNN)
Connect each spot to its k closest spatial neighbors.
import numpy as np
from sklearn.neighbors import NearestNeighbors
import torch
def knn_graph(coords, k=6):
"""
Build kNN graph from spatial coordinates.
Args:
coords: [N, 2] array of (x, y) positions
k: number of neighbors
Returns:
edge_index: [2, E] tensor in COO format
"""
nbrs = NearestNeighbors(n_neighbors=k+1, algorithm='kd_tree').fit(coords)
distances, indices = nbrs.kneighbors(coords)
# Exclude self (first neighbor)
src = np.repeat(np.arange(len(coords)), k)
dst = indices[:, 1:].reshape(-1)
# Make undirected
edge_index = np.stack([src, dst], axis=0)
edge_index = np.concatenate([edge_index, edge_index[::-1]], axis=1)
edge_index = np.unique(edge_index, axis=1)
    return torch.as_tensor(edge_index, dtype=torch.long)

Pros: Simple, controllable density
Cons: Fixed k ignores tissue density variation
Effect of k on smoothing:
- Small k (4-6): Preserves sharp boundaries, less smoothing
- Large k (12-20): More smoothing, boundaries blur, risk of over-smoothing
Radius (ε-ball) Graph
Connect all spots within distance ε.
def radius_graph(coords, radius=100.0):
"""
Build radius graph: connect nodes within distance threshold.
"""
from sklearn.neighbors import radius_neighbors_graph
A = radius_neighbors_graph(coords, radius, mode='connectivity')
edge_index = torch.as_tensor(np.array(A.nonzero()), dtype=torch.long)
    return edge_index

Pros: Density-adaptive, physically meaningful
Cons: Can create disconnected components if ε too small
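Given that risk, it helps to verify connectivity after construction; a quick check using scipy (assuming the `radius_graph` helper above):

```python
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.csgraph import connected_components

def check_connectivity(edge_index, num_nodes):
    """Count connected components of the graph (1 = fully connected)."""
    src, dst = edge_index.numpy()
    A = coo_matrix((np.ones(len(src)), (src, dst)), shape=(num_nodes, num_nodes))
    n_components, _ = connected_components(A, directed=False)
    return n_components

# e.g., increase the radius until the graph forms a single component
# n = check_connectivity(radius_graph(coords, radius=100.0), len(coords))
```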
Delaunay Triangulation
Build triangles connecting nearest points without overlap—preserves natural geometry.
from scipy.spatial import Delaunay
import numpy as np
def delaunay_graph(coords):
"""
Build graph from Delaunay triangulation.
Used in SpaGCN for natural spatial connectivity.
"""
tri = Delaunay(coords)
edges = set()
for simplex in tri.simplices:
for i in range(3):
edge = tuple(sorted([simplex[i], simplex[(i+1)%3]]))
edges.add(edge)
edges = np.array(list(edges)).T
# Make undirected
edge_index = np.concatenate([edges, edges[::-1]], axis=1)
    return torch.as_tensor(edge_index, dtype=torch.long)

Pros: Natural geometric structure, used in SpaGCN
Cons: Can create long edges in sparse regions
2. Expression Similarity Graphs
Connect spots with similar gene expression, independent of distance.
def expression_similarity_graph(X, k=10, metric='cosine'):
"""
Build kNN graph in expression space.
Args:
X: [N, genes] expression matrix
k: number of neighbors
metric: 'cosine' or 'euclidean'
"""
from sklearn.neighbors import NearestNeighbors
nbrs = NearestNeighbors(n_neighbors=k+1, metric=metric).fit(X)
distances, indices = nbrs.kneighbors(X)
src = np.repeat(np.arange(len(X)), k)
dst = indices[:, 1:].reshape(-1)
edge_index = np.stack([src, dst], axis=0)
# Undirected
edge_index = np.concatenate([edge_index, edge_index[::-1]], axis=1)
edge_index = np.unique(edge_index, axis=1)
    return torch.as_tensor(edge_index, dtype=torch.long)

Pros: Captures functional/molecular similarity
Cons: May connect distant regions; ignores spatial structure
3. Hybrid (Multi-view) Graphs
Combine spatial and expression information.
Weighted combination: $A_{\text{hybrid}} = \alpha A_{\text{spatial}} + (1 - \alpha) A_{\text{expr}}$
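A minimal sketch of the fixed-weight version, assuming the two graphs have already been converted to dense [N, N] adjacency matrices:

```python
def hybrid_adjacency(A_spatial, A_expr, alpha=0.5):
    """Fixed-weight blend of two [N, N] adjacency matrices (sketch).
    alpha=1 keeps only spatial structure; alpha=0 only expression."""
    return alpha * A_spatial + (1.0 - alpha) * A_expr
```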
Learnable weighting:
class MultiViewGNN(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.gcn_spatial = GCNLayer(in_channels, out_channels)
self.gcn_expression = GCNLayer(in_channels, out_channels)
self.alpha = nn.Parameter(torch.tensor(0.5)) # Learnable weight
def forward(self, x, edge_index_spatial, edge_index_expr):
h_spatial = self.gcn_spatial(x, edge_index_spatial)
h_expr = self.gcn_expression(x, edge_index_expr)
alpha = torch.sigmoid(self.alpha) # Ensure [0,1]
        return alpha * h_spatial + (1 - alpha) * h_expr

Used in: STAGATE, conST, GraphST
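Putting it together with the graph builders from earlier (a sketch; the 32-dimensional output is an arbitrary choice):

```python
import torch

# Two complementary views of the same spots
edge_index_spatial = knn_graph(coords, k=6)              # spatial proximity
edge_index_expr = expression_similarity_graph(X, k=10)   # molecular similarity

model = MultiViewGNN(in_channels=X.shape[1], out_channels=32)
h = model(torch.tensor(X, dtype=torch.float32),
          edge_index_spatial, edge_index_expr)           # [N, 32] fused embedding
```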
4. Image-Based Graphs (Histology Integration)
If H&E images are available, use morphological similarity.
def image_weighted_graph(coords, image_features, k=6, alpha=0.5):
"""
Weighted graph combining spatial distance and image feature similarity.
Args:
coords: [N, 2] spatial positions
image_features: [N, D] CNN features from histology patches
alpha: weight between spatial and image (0=spatial only, 1=image only)
"""
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity
# Spatial kNN
nbrs_spatial = NearestNeighbors(n_neighbors=k+1).fit(coords)
dist_spatial, idx_spatial = nbrs_spatial.kneighbors(coords)
# Image similarity
img_sim = cosine_similarity(image_features)
edges = []
weights = []
for i in range(len(coords)):
for j_idx, j in enumerate(idx_spatial[i, 1:]): # Skip self
spatial_sim = 1.0 / (1.0 + dist_spatial[i, j_idx+1])
image_sim_val = img_sim[i, j]
combined_weight = alpha * spatial_sim + (1 - alpha) * image_sim_val
edges.append([i, j])
weights.append(combined_weight)
edge_index = torch.tensor(edges, dtype=torch.long).T
edge_weight = torch.tensor(weights, dtype=torch.float)
    return edge_index, edge_weight

Used in: SpaGCN, conST, MERGE
State-of-the-Art Methods in Spatial Transcriptomics
Method Comparison Table
| Method | Year | Graph Type | GNN Architecture | Key Innovation | Task |
|---|---|---|---|---|---|
| SpaGCN | 2021 | Spatial + histology weighted | GCN | Delaunay + image features, domain detection | Clustering, SVG detection |
| STAGATE | 2022 | Spatial kNN | Graph Attention Autoencoder | Learned attention, autoencoder framework | Clustering, denoising, trajectory |
| SEDR | 2021 | Spatial kNN | VGAE + deep autoencoder | Dual embedding (expression + spatial VGAE) | Clustering, embedding |
| GraphST | 2023 | Multi-view (spatial + expr) | GCN + contrastive learning | Self-supervised contrastive, multi-sample integration | Clustering, deconvolution |
| STdGCN | 2023 | Hybrid (spot + pseudo-spot) | GCN | Integrate scRNA reference via hybrid graph | Deconvolution |
| stAA | 2024 | Spatial kNN | Adversarial VGAE | WGAN regularization on embeddings | Clustering |
| stMCDI | 2024 | Spatial | GNN encoder in diffusion model | Diffusion-based imputation with GNN | Missing value imputation |
| MERGE | 2023 | Hierarchical image patches | Multi-scale GNN | Image → gene prediction via hierarchical graph | Prediction from histology |
Deep Dive: SpaGCN
Paper: Hu et al., Nature Methods (2021)
Problem: Identify spatial domains using both gene expression and tissue morphology.
Graph Construction:
- Build Delaunay triangulation from spot coordinates
- Weight edges by a combination of:
  - Spatial distance
  - Histology image similarity (compare H&E patches around each spot)
GNN: Standard GCN with weighted adjacency
Loss: Combines:
- Clustering loss (Louvain or k-means on embeddings)
- Spatial regularization (neighbors should have similar cluster assignments)
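The sketch below shows only the clustering side; the spatial term can be written, in one hedged form, as a penalty on disagreement between neighboring soft assignments (an illustrative formulation, not SpaGCN's exact loss):

```python
import torch.nn.functional as F

def spatial_smoothness_loss(logits, edge_index):
    """Encourage neighboring spots to share cluster assignments by
    penalizing KL divergence across edges (illustrative sketch)."""
    log_p = F.log_softmax(logits, dim=1)   # [N, C] log soft assignments
    src, dst = edge_index
    # KL(p_dst || p_src), averaged over all edges
    return F.kl_div(log_p[src], log_p[dst].exp(), reduction='batchmean')
```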
Code Sketch:
# Weighted adjacency from spatial + image
A_weighted = spatial_weight * image_weight
# GCN embedding
H = gcn_layers(X, A_weighted)
# Clustering
clusters = louvain_clustering(H)
# Find spatially variable genes per domain
svg_genes = differential_expression_per_domain(clusters)

Strengths: Multi-modal, interpretable, works well on Visium data
Limitations: Static graph, sensitive to weighting hyperparameters
Deep Dive: STAGATE
Paper: Dong & Chen, Nature Communications (2022)
Problem: Learn representations robust to noise, preserve spatial structure, enable downstream tasks (clustering, pseudotime).
Architecture: Graph Attention Autoencoder
- Encoder: Multi-layer GAT → latent embedding Z
- Decoder: Reconstruct both expression X and adjacency A
Key Innovation: Attention mechanism learns which neighbors are relevant, mitigating over-smoothing.
Loss:

$$\mathcal{L} = \mathcal{L}_{\text{expr}} + \lambda \, \mathcal{L}_{\text{graph}}$$

Where:
- $\mathcal{L}_{\text{expr}}$ = MSE for expression reconstruction
- $\mathcal{L}_{\text{graph}}$ = binary cross-entropy for adjacency reconstruction
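A sketch of how the combined objective might be computed in a training step (treating λ as a tunable weight is an assumption of this sketch):

```python
import torch.nn.functional as F

def stagate_loss(x, x_recon, adj_target, adj_recon, lam=1.0):
    """Reconstruction loss: expression MSE + adjacency BCE (sketch).
    adj_target: dense [N, N] binary adjacency; adj_recon: sigmoid outputs."""
    loss_expr = F.mse_loss(x_recon, x)
    loss_graph = F.binary_cross_entropy(adj_recon, adj_target)
    return loss_expr + lam * loss_graph
```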
Code Outline:
class STAGATE(nn.Module):
def __init__(self, n_genes, hidden_dim, latent_dim):
super().__init__()
self.encoder = nn.ModuleList([
GATLayer(n_genes, hidden_dim),
GATLayer(hidden_dim, latent_dim)
])
self.decoder_expr = nn.Linear(latent_dim, n_genes)
        # Graph decoder: inner product of embeddings, computed inline in forward
def forward(self, x, edge_index):
# Encode
h = x
for layer in self.encoder:
h = F.elu(layer(h, edge_index))
# Decode
x_recon = self.decoder_expr(h)
adj_recon = torch.sigmoid(h @ h.T) # Inner product for adjacency
        return x_recon, adj_recon, h

Strengths: Adaptive smoothing, rich embeddings for multiple tasks
Limitations: More hyperparameters, potential overfitting on small datasets
Deep Dive: GraphST (Contrastive Learning)
Paper: Long et al., Nature Communications (2023)
Problem: Integrate multiple tissue sections, perform clustering and deconvolution.
Key Idea: Self-supervised contrastive learning on graph embeddings.
Method:
- Build two views of the same graph (e.g., different augmentations or spatial vs. expression graphs)
- Encode both with GCN
- Contrastive loss: pull together embeddings from same spot, push apart different spots
$$\mathcal{L}_i = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_i') / \tau\right)}{\sum_{j \neq i} \exp\left(\mathrm{sim}(z_i, z_j) / \tau\right)}$$

Where $z_i'$ is the augmented view of spot $i$ and $\tau$ is a temperature.
Benefits:
- Learns robust representations without labeled data
- Aligns multiple slides in shared embedding space
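The `contrastive_loss` used in the code sketch below is not defined in this post; a minimal InfoNCE implementation it could stand in for, assuming cosine similarity and in-batch negatives:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(z1, z2, temperature=0.5):
    """InfoNCE: matching spots across views are positives,
    all other spots in the batch are negatives (sketch)."""
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.T / temperature          # [N, N] cosine similarities
    targets = torch.arange(z1.size(0), device=z1.device)
    # Symmetric loss over both view directions
    return 0.5 * (F.cross_entropy(logits, targets) +
                  F.cross_entropy(logits.T, targets))
```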
Code Sketch:
# Create two views
edge_index_view1 = knn_graph(coords, k=6)
edge_index_view2 = expression_similarity_graph(X, k=10)
# Encode both
z1 = gnn_encoder(X, edge_index_view1)
z2 = gnn_encoder(X, edge_index_view2)
# Contrastive loss (InfoNCE)
loss = contrastive_loss(z1, z2, temperature=0.5)

Implementation Guide
Complete Pipeline: Spatial Domain Detection
Here's a complete, runnable pipeline from raw Visium data to domain detection using a GNN.
Step 1: Load Data
import scanpy as sc
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
# Load Visium data (standard scanpy format)
adata = sc.read_visium("path/to/visium/data")
# Extract coordinates and expression
coords = adata.obsm['spatial'] # [N, 2]
X = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X # [N, genes]
# Highly variable genes
sc.pp.highly_variable_genes(adata, n_top_genes=2000)
X = X[:, adata.var['highly_variable'].values]
# Normalize
X = StandardScaler().fit_transform(X)
# To tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
coords_np = np.array(coords)

Step 2: Build Graph
from sklearn.neighbors import NearestNeighbors
def build_spatial_graph(coords, k=6):
nbrs = NearestNeighbors(n_neighbors=k+1).fit(coords)
_, indices = nbrs.kneighbors(coords)
src = np.repeat(np.arange(len(coords)), k)
dst = indices[:, 1:].reshape(-1)
edge_index = np.stack([src, dst], axis=0)
# Undirected
edge_index = np.concatenate([edge_index, edge_index[::-1]], axis=1)
edge_index = np.unique(edge_index, axis=1)
return torch.tensor(edge_index, dtype=torch.long)
edge_index = build_spatial_graph(coords_np, k=6)

Step 3: Define Model
import torch.nn as nn
import torch.nn.functional as F

# Reuses the GCNLayer defined in the GCN section above
class SpatialGNN(nn.Module):
def __init__(self, n_genes, hidden_dim=128, latent_dim=32, n_clusters=7):
super().__init__()
self.encoder = nn.ModuleList([
GCNLayer(n_genes, hidden_dim),
GCNLayer(hidden_dim, latent_dim)
])
self.cluster_head = nn.Linear(latent_dim, n_clusters)
def forward(self, x, edge_index):
h = x
for layer in self.encoder:
h = layer(h, edge_index)
h = F.relu(h)
h = F.dropout(h, p=0.3, training=self.training)
logits = self.cluster_head(h)
return h, logits
model = SpatialGNN(n_genes=X.shape[1], n_clusters=7)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Step 4: Training Loop
from sklearn.cluster import KMeans
def train_epoch(model, x, edge_index, optimizer):
model.train()
optimizer.zero_grad()
# Forward
embeddings, logits = model(x, edge_index)
# Pseudo-labels from k-means (self-supervised)
with torch.no_grad():
kmeans = KMeans(n_clusters=7, random_state=42)
        pseudo_labels = kmeans.fit_predict(embeddings.detach().cpu().numpy())
pseudo_labels = torch.tensor(pseudo_labels, dtype=torch.long)
# Cross-entropy loss
loss = F.cross_entropy(logits, pseudo_labels)
loss.backward()
optimizer.step()
return loss.item()
# Train
for epoch in range(100):
loss = train_epoch(model, X_tensor, edge_index, optimizer)
if epoch % 10 == 0:
print(f"Epoch {epoch}, Loss: {loss:.4f}")Step 5: Extract Clusters and Visualize
model.eval()
with torch.no_grad():
embeddings, logits = model(X_tensor, edge_index)
clusters = logits.argmax(dim=1).cpu().numpy()
# Add to adata
adata.obs['gnn_cluster'] = clusters.astype(str)  # categorical for plotting
# Visualize on tissue
import scanpy as sc
sc.pl.spatial(adata, color='gnn_cluster', spot_size=150)

Practical Considerations
1. Over-smoothing Problem
Issue: With many GNN layers, all node features converge to the same value.
Why: Each layer mixes neighbors' features, leading to exponential diffusion.
Solutions:
- Limit layers: 2-3 layers are often enough for ST (tissue neighborhoods are local)
- Skip connections: add residual connections, e.g. `h_new = gcn_layer(h) + h` (see the sketch below)
- Adaptive depth (DropEdge, PairNorm): randomly drop edges or normalize layer outputs
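For instance, a residual variant of the earlier `GCNLayer` (a sketch; the projection handles a change in feature width):

```python
import torch.nn as nn

class ResGCNLayer(nn.Module):
    """GCN layer with a residual (skip) connection to curb over-smoothing."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.gcn = GCNLayer(in_channels, out_channels)  # from the GCN section
        # Project the input if dimensions differ, so the skip can be added
        self.proj = (nn.Identity() if in_channels == out_channels
                     else nn.Linear(in_channels, out_channels))

    def forward(self, x, edge_index):
        return self.gcn(x, edge_index) + self.proj(x)
```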
Rule of thumb for ST:
- Small tissue sections: 2 layers
- Whole-slide: 3-4 layers max
- Multi-sample integration: Consider global pooling instead of deeper GNN
2. Choosing Hyperparameter k (Number of Neighbors)
Trade-off:
- Low k (4-6): Sharp boundaries, less smoothing, more noise sensitivity
- High k (12-20): Smoother, more stable, but blurs boundaries
Guideline:
- Delaunay: Natural choice; yields ~6 neighbors per spot on average
- Brain cortex / layered tissue: Lower k to preserve layers
- Tumor heterogeneity: Mid-range k (6-8)
- Dense spot arrays (Visium HD): Higher k (10-15)
Experiment: Try several values (e.g., k ∈ {4, 6, 8, 12}) and check (a sweep is sketched below):
- Silhouette score on clusters
- Agreement with known anatomical regions
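A minimal sweep, reusing `build_spatial_graph` from the pipeline; `fit_and_embed` is a hypothetical helper standing in for a full training run that returns embeddings:

```python
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# fit_and_embed is a hypothetical helper: train a GNN (e.g., the SpatialGNN
# pipeline above) on the given graph and return [N, latent] embeddings.
for k in [4, 6, 8, 12]:
    edge_index = build_spatial_graph(coords_np, k=k)
    emb = fit_and_embed(X_tensor, edge_index)
    labels = KMeans(n_clusters=7, random_state=42).fit_predict(emb)
    print(f"k={k}: silhouette = {silhouette_score(emb, labels):.3f}")
```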
3. Normalization and Preprocessing
Critical steps:
- Gene selection: Use highly variable genes (1000-3000)
- Expression normalization: Log1p + standard scaling
- Graph normalization: Symmetric (built into GCN)
- Feature scaling: StandardScaler before training
import scanpy as sc
from sklearn.preprocessing import StandardScaler
# Standard pipeline
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=2000)
adata = adata[:, adata.var.highly_variable]
X = StandardScaler().fit_transform(adata.X.toarray())

4. Evaluation Metrics
For clustering:
- Adjusted Rand Index (ARI): Compare with ground-truth annotations
- Normalized Mutual Information (NMI)
- Silhouette score: Cluster separation in embedding space
For deconvolution:
- RMSE vs. known proportions (if synthetic data)
- Correlation with reference cell-type markers
For imputation:
- RMSE on held-out genes
- Pearson correlation with ground truth
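For imputation, a minimal evaluation sketch (here `X_true` and `X_imputed` are assumed 1-D arrays of held-out expression values):

```python
import numpy as np
from scipy.stats import pearsonr

# X_true, X_imputed: held-out (masked) expression values (assumed given)
rmse = np.sqrt(np.mean((X_true - X_imputed) ** 2))
r, _ = pearsonr(X_true, X_imputed)
print(f"RMSE: {rmse:.3f}, Pearson r: {r:.3f}")
```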
from sklearn.metrics import adjusted_rand_score, silhouette_score
# Clustering evaluation
ari = adjusted_rand_score(true_labels, predicted_clusters)
silhouette = silhouette_score(embeddings, predicted_clusters)
print(f"ARI: {ari:.3f}, Silhouette: {silhouette:.3f}")5. Common Pitfalls
| Problem | Symptom | Solution |
|---|---|---|
| Over-smoothing | All embeddings identical | Reduce layers, add skip connections |
| Disconnected graph | NaN in training | Check graph connectivity, increase k or the radius ε |
| Unstable GAT | Loss spikes | Lower learning rate, increase dropout on attention |
| Poor clustering | Random assignments | Better graph construction, tune k, check normalization |
| High memory | OOM errors | Use mini-batching (GraphSAGE), neighbor sampling |
Advanced Topics
1. Hypergraph Neural Networks
Standard GNN: Pairwise edges (each edge connects 2 nodes)
Hypergraph: Hyperedges connect any number of nodes simultaneously
Use case in ST: Group of spots in a tissue domain share a hyperedge
Example: Hypergraph Neural Networks Reveal Spatial Domains
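For intuition, a hedged sketch of one HGNN-style hypergraph convolution with unit hyperedge weights, using the incidence matrix $H$:

```python
import torch

def hypergraph_conv(X, H, Theta):
    """One HGNN-style convolution step with unit hyperedge weights (sketch).
    X: [N, F] node features; H: [N, E] incidence matrix (H[i, e] = 1 if
    node i belongs to hyperedge e); Theta: [F, F_out] learnable weights."""
    Dv = H.sum(dim=1).clamp(min=1)              # node degrees
    De = H.sum(dim=0).clamp(min=1)              # hyperedge sizes
    Dv_inv_sqrt = Dv.pow(-0.5)
    # D_v^{-1/2} H D_e^{-1} H^T D_v^{-1/2} X Theta
    Xn = Dv_inv_sqrt.unsqueeze(1) * X
    msg = (H.T @ Xn) / De.unsqueeze(1)          # aggregate nodes -> hyperedges
    out = Dv_inv_sqrt.unsqueeze(1) * (H @ msg)  # broadcast hyperedges -> nodes
    return torch.relu(out @ Theta)
```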
2. Hierarchical and Multi-Scale Graphs
Idea: Build graphs at multiple resolutions
- Fine-scale: spot-to-spot
- Coarse-scale: domain-to-domain (via clustering)
Benefits: Capture both local and global patterns
Example: MERGE builds hierarchical image patch graphs
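A hedged sketch of building the coarse level from spot-level clusters (the helper name is illustrative):

```python
import torch

def coarsen_graph(edge_index, labels):
    """Build a domain-level graph: nodes = clusters, edges = cluster pairs
    linked by at least one spot-level spatial edge (sketch)."""
    src, dst = edge_index
    c_src, c_dst = labels[src], labels[dst]
    mask = c_src != c_dst                     # keep cross-cluster edges only
    coarse = torch.stack([c_src[mask], c_dst[mask]], dim=0)
    return torch.unique(coarse, dim=1)        # deduplicate edges

# labels: [N] tensor of cluster ids, e.g., from KMeans on embeddings
```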
3. Graph Transformers
Replace message-passing with full self-attention over all nodes.
Pros: No over-smoothing, long-range dependencies
Cons: O(N²) attention complexity
Emerging work: Graphormer, GraphGPS
Resources and Next Steps
Papers (Recommended Reading Order)
1. GNN Foundations:
   - Kipf & Welling (2017): Semi-Supervised Classification with GCNs
   - Hamilton et al. (2017): GraphSAGE
   - Veličković et al. (2018): Graph Attention Networks
2. ST Applications:
   - SpaGCN: Nature Methods 2021
   - STAGATE: Nature Communications 2022
   - GraphST: Nature Communications 2023
3. Advanced:
   - Hypergraph, hierarchical, and graph-transformer approaches (see Advanced Topics above)
Code Repositories
- PyTorch Geometric: https://pytorch-geometric.readthedocs.io/
- DGL (Deep Graph Library): https://www.dgl.ai/
- STAGATE: https://github.com/zhanglabtools/STAGATE
- GraphST: https://github.com/JinmiaoChenLab/GraphST
Practice Projects
- Reproduce SpaGCN on a public Visium dataset (e.g., mouse brain)
- Compare GCN vs. GAT on your own ST data—measure ARI, runtime, smoothness
- Build a hybrid graph (spatial + expression) and tune the weighting parameter
- Implement contrastive learning for multi-sample integration
Conclusion
Graph Neural Networks provide a powerful, principled framework for spatial transcriptomics analysis by:
- Respecting spatial structure through graph topology
- Learning adaptive representations via message passing
- Integrating multi-modal data (expression, coordinates, images)
- Scaling to large datasets with sampling and mini-batching
The key to success is:
- Graph construction: Choose topology that reflects your biology
- Architecture selection: Match model complexity to data size and task
- Regularization: Prevent over-smoothing and overfitting
- Evaluation: Use domain-relevant metrics and biological validation
As spatial omics technologies advance (Visium HD, Xenium, CosMx), GNNs will remain essential tools for discovering tissue organization, cell-cell interactions, and disease mechanisms.
Happy graph learning!