Introduction to Graph Neural Networks for Spatial Transcriptomics

Completed
Research & Implementation
Feb 2025

A comprehensive guide to understanding and applying GNNs to spatial transcriptomics data, from foundational concepts to cutting-edge research implementations

Tech Stack

PyTorch · PyTorch Geometric · Python · Graph Neural Networks · Spatial Transcriptomics

Tags

deep learning · bioinformatics · graph neural networks · spatial transcriptomics · computational biology

Introduction to Graph Neural Networks for Spatial Transcriptomics

A deep dive into applying Graph Neural Networks (GNNs) to spatial transcriptomics research, bridging deep learning theory with cutting-edge biological applications.

Table of Contents

  1. Introduction
  2. From CNNs to GNNs
  3. Core GNN Architectures
  4. Graph Construction Strategies
  5. State-of-the-Art Methods in Spatial Transcriptomics
  6. Implementation Guide
  7. Practical Considerations

Introduction

Spatial transcriptomics (ST) captures gene expression while preserving the physical location of cells in tissue. This spatial context creates a natural graph structure where:

  • Nodes = spots/cells with gene expression vectors
  • Edges = spatial proximity or molecular similarity
  • Task = identify tissue domains, cell types, or spatial patterns

Graph Neural Networks provide a powerful framework to leverage this structure, enabling message-passing that respects both spatial topology and molecular relationships.


From CNNs to GNNs

Why CNNs Can't Handle Irregular Spatial Data

Convolutional Neural Networks excel at structured data (images) because:

  1. Local connectivity: Each pixel has a fixed neighborhood (e.g., 3×3)
  2. Weight sharing: The same filter slides across all positions, making the model parameter-efficient
  3. Translation invariance: Patterns learned anywhere apply everywhere

The convolution operation at pixel (i, j) is:

y_{i,j} = \sum_{a=-1}^{1} \sum_{b=-1}^{1} w_{a,b} \cdot x_{i+a,\,j+b}

The problem for spatial transcriptomics: Spots are irregularly spaced and often have varying numbers of neighbors. We can't use a fixed 3×3 kernel.

Message Passing: The Key Insight

GNNs generalize convolution to arbitrary graph structures through message passing:

h_i^{(l+1)} = \sigma\left(W^{(l)} \cdot \text{AGGREGATE}\{h_j^{(l)} : j \in \mathcal{N}(i)\}\right)

Where:

  • h_i^{(l)} = feature vector of node i at layer l
  • \mathcal{N}(i) = neighbors of node i
  • W^{(l)} = learnable weight matrix
  • \sigma = activation function (ReLU, etc.)

This is conceptually similar to CNN convolution, but:

  • Neighbors are defined by edges, not spatial proximity on a grid
  • Each node can have a variable number of neighbors
  • Weights are shared across all nodes, just like CNN filters
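
To make the AGGREGATE step concrete, here is a minimal sketch of one mean-aggregation round in plain PyTorch; the function name and the toy graph are illustrative, not from any library:

import torch

def mean_message_passing(x, edge_index):
    """
    One round of mean-aggregation message passing (illustrative sketch).
    x:          [N, F] node features
    edge_index: [2, E] edges as (source, destination) pairs
    """
    src, dst = edge_index
    N = x.size(0)

    # Sum incoming messages at each destination node
    agg = torch.zeros_like(x)
    agg.index_add_(0, dst, x[src])

    # Divide by in-degree to get the mean (clamp avoids division by zero)
    deg = torch.bincount(dst, minlength=N).clamp(min=1).float()
    return agg / deg.unsqueeze(1)

# Toy example: node 2 receives the mean of nodes 0 and 1
x = torch.tensor([[1.0, 0.0], [3.0, 2.0], [0.0, 0.0]])
edge_index = torch.tensor([[0, 1], [2, 2]])
print(mean_message_passing(x, edge_index))  # row 2 becomes [2.0, 1.0]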

Core GNN Architectures

1. Graph Convolutional Network (GCN)

Paper: Kipf & Welling (2017)

Core Idea: Normalize neighbor aggregation by degree to prevent scale issues.

Mathematical Formulation

Given:

  • \mathbf{X} \in \mathbb{R}^{N \times F}: node features (N spots, F genes)
  • \mathbf{A} \in \mathbb{R}^{N \times N}: adjacency matrix
  • \mathbf{D}: degree matrix (diagonal, where D_{ii} = \sum_j A_{ij})

Add self-loops and normalize:

\tilde{\mathbf{A}} = \mathbf{A} + \mathbf{I} \quad \text{(add self-connections)}

\tilde{D}_{ii} = \sum_j \tilde{A}_{ij} \quad \text{(recompute degrees)}

The GCN layer performs symmetric normalization:

\mathbf{H}^{(l+1)} = \sigma\left(\tilde{\mathbf{D}}^{-1/2} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-1/2} \mathbf{H}^{(l)} \mathbf{W}^{(l)}\right)

Intuition: The symmetric normalization \tilde{\mathbf{D}}^{-1/2} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-1/2} creates a normalized adjacency that averages features fairly across neighbors, preventing high-degree nodes from dominating.

PyTorch Implementation (Matrix Form)

import torch
import torch.nn as nn
 
def normalize_adj(edge_index, num_nodes):
    """
    Build symmetric normalized adjacency: D^{-1/2} A D^{-1/2}
    Returns sparse tensor indices and values.
    """
    i, j = edge_index
    # Add self-loops
    self_loops = torch.arange(num_nodes, device=edge_index.device)
    i = torch.cat([i, self_loops])
    j = torch.cat([j, self_loops])
 
    # Compute degree
    deg = torch.bincount(j, minlength=num_nodes).float()
    deg_inv_sqrt = deg.clamp(min=1).pow(-0.5)
 
    # Symmetric normalization weights
    vals = deg_inv_sqrt[i] * deg_inv_sqrt[j]
    return torch.stack([i, j], dim=0), vals
 
class GCNLayer(nn.Module):
    """
    Graph Convolutional Layer with sparse adjacency multiply.
    Returns pre-activation outputs; the calling model applies the
    nonlinearity, so the final layer can emit raw logits.
    """
    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, x, edge_index):
        N = x.size(0)
        idx, vals = normalize_adj(edge_index, N)
        A_norm = torch.sparse_coo_tensor(idx, vals, (N, N), device=x.device)

        x = self.lin(x)                    # [N, F_out]
        x = torch.sparse.mm(A_norm, x)     # Sparse matrix multiply
        return x
 
# Full GCN Model
class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(GCNLayer(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.layers.append(GCNLayer(hidden_channels, hidden_channels))
        self.layers.append(GCNLayer(hidden_channels, out_channels))
 
    def forward(self, x, edge_index):
        for i, layer in enumerate(self.layers[:-1]):
            x = layer(x, edge_index)
            x = torch.relu(x)
            x = torch.dropout(x, p=0.5, train=self.training)
        x = self.layers[-1](x, edge_index)
        return x
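
As a quick sanity check, a toy usage sketch for the classes above (random features and an arbitrary edge list, purely illustrative):

# Toy usage: 5 spots with 16 genes each, 7 output clusters
x = torch.randn(5, 16)
edge_index = torch.tensor([[0, 1, 2, 3],    # sources
                           [1, 2, 3, 4]])   # destinations
model = GCN(in_channels=16, hidden_channels=32, out_channels=7)
logits = model(x, edge_index)
print(logits.shape)  # torch.Size([5, 7])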

Strengths:

  • Simple, efficient, mathematically principled
  • Fast on small-to-medium graphs
  • Smooth aggregation good for homogeneous tissue

Weaknesses:

  • Fixed (isotropic) weights—all neighbors treated equally
  • Transductive (needs full graph at training)
  • Over-smoothing with many layers

2. GraphSAGE

Paper: Hamilton et al. (2017)

Core Idea: Sample neighbors and use robust aggregators (mean, LSTM, max-pool) to scale to large graphs and enable inductive learning (generalize to unseen nodes/graphs).

Update Rule

h_i^{(l+1)} = \sigma\left(\mathbf{W} \cdot \text{CONCAT}\left[h_i^{(l)}, \text{AGG}\{h_j^{(l)} : j \in \mathcal{N}(i)\}\right]\right)

Key differences from GCN:

  1. Concatenate self-feature with aggregated neighbors (preserves identity)
  2. Sample a fixed number of neighbors (e.g., 10) instead of using all
  3. Aggregator choices: mean, max-pool, LSTM

PyTorch Implementation

import random

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class SAGEConv(nn.Module):
    """
    GraphSAGE layer with mean aggregation.
    """
    def __init__(self, in_channels, out_channels, normalize=True):
        super().__init__()
        self.lin_self = nn.Linear(in_channels, out_channels)
        self.lin_neigh = nn.Linear(in_channels, out_channels)
        self.normalize = normalize
 
    def forward(self, x, edge_index, num_neighbors_sample=10):
        N = x.size(0)
 
        # Sample neighbors (for scalability)
        src, dst = edge_index
        # Group by destination
        neighbors = {}
        for s, d in zip(src.tolist(), dst.tolist()):
            neighbors.setdefault(d, []).append(s)
 
        # Randomly sample up to num_neighbors_sample neighbors, then mean-aggregate
        neigh_features = []
        for i in range(N):
            neigh_i = neighbors.get(i, [])
            if len(neigh_i) > num_neighbors_sample:
                neigh_i = random.sample(neigh_i, num_neighbors_sample)
            if neigh_i:
                idx = torch.tensor(neigh_i, device=x.device)
                neigh_features.append(x[idx].mean(dim=0))
            else:
                neigh_features.append(torch.zeros(x.size(1), device=x.device))
 
        neigh_agg = torch.stack(neigh_features)  # [N, F]
 
        # Combine self and neighbor (summing two linear maps is equivalent
        # to one linear layer applied to the concatenation)
        out = self.lin_self(x) + self.lin_neigh(neigh_agg)
 
        if self.normalize:
            out = F.normalize(out, p=2, dim=1)
 
        return out
 
class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.layers.append(SAGEConv(hidden_channels, hidden_channels))
        self.layers.append(SAGEConv(hidden_channels, out_channels))
 
    def forward(self, x, edge_index):
        for layer in self.layers[:-1]:
            x = layer(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)
        x = self.layers[-1](x, edge_index)
        return x

Strengths:

  • Scales to massive graphs via neighbor sampling
  • Inductive: can generalize to new slides without retraining
  • Multiple aggregator options

Weaknesses:

  • Sampling introduces variance
  • More hyperparameters (sample size per layer)
  • Still treats all sampled neighbors equally

3. Graph Attention Network (GAT)

Paper: Veličković et al. (2018)

Core Idea: Learn per-edge attention weights so the model can focus on important neighbors and ignore noisy ones.

Attention Mechanism

For each edge (i, j):

e_{ij} = \text{LeakyReLU}\left(\mathbf{a}^T [\mathbf{W} h_i \,\|\, \mathbf{W} h_j]\right)

Normalize with softmax over all neighbors:

\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}

Aggregate with learned weights:

h_i^{(l+1)} = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} \cdot \mathbf{W} h_j^{(l)}\right)

Multi-head attention (like Transformers): Run K independent attention heads and concatenate:

h_i^{(l+1)} = \big\Vert_{k=1}^{K}\, \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k} \mathbf{W}^{k} h_j^{(l)}\right)

PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class GATLayer(nn.Module):
    """
    Graph Attention Layer (single head).
    """
    def __init__(self, in_channels, out_channels, dropout=0.6, alpha=0.2):
        super().__init__()
        self.W = nn.Linear(in_channels, out_channels, bias=False)
        self.a = nn.Parameter(torch.zeros(size=(2 * out_channels, 1)))
        self.leakyrelu = nn.LeakyReLU(alpha)
        self.dropout = dropout
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
 
    def forward(self, x, edge_index):
        # Linear transformation
        Wh = self.W(x)  # [N, out_channels]
        N = Wh.size(0)
 
        # Compute attention coefficients
        src, dst = edge_index
        # Concatenate source and destination features
        Wh_src = Wh[src]  # [E, out_channels]
        Wh_dst = Wh[dst]  # [E, out_channels]
        concat_features = torch.cat([Wh_src, Wh_dst], dim=1)  # [E, 2*out_channels]
 
        # Attention mechanism
        e = self.leakyrelu(concat_features @ self.a).squeeze(-1)  # [E]

        # Normalize per destination node (readable O(N·E) loop;
        # a scatter-based softmax is the vectorized alternative)
        attention_weights = torch.zeros_like(e)

        for i in range(N):
            mask = (dst == i)
            if mask.sum() > 0:
                attention_weights[mask] = F.softmax(e[mask], dim=0)
 
        # Apply dropout to attention
        attention_weights = F.dropout(attention_weights, self.dropout, training=self.training)
 
        # Aggregate
        out = torch.zeros_like(Wh)
        for i in range(N):
            mask = (dst == i)
            if mask.sum() > 0:
                out[i] = (attention_weights[mask].unsqueeze(1) * Wh[src[mask]]).sum(dim=0)
 
        return out
 
class GAT(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_heads=8, dropout=0.6):
        super().__init__()
        self.dropout = dropout
 
        # Multi-head first layer
        self.attentions = nn.ModuleList([
            GATLayer(in_channels, hidden_channels, dropout=dropout)
            for _ in range(num_heads)
        ])
 
        # Single-head output layer
        self.out_att = GATLayer(
            hidden_channels * num_heads,
            out_channels,
            dropout=dropout
        )
 
    def forward(self, x, edge_index):
        # First layer: concatenate multi-head outputs, then activation + dropout
        x = torch.cat([att(x, edge_index) for att in self.attentions], dim=1)
        x = F.elu(x)
        x = F.dropout(x, self.dropout, training=self.training)

        # Output layer
        x = self.out_att(x, edge_index)
        return x

Strengths:

  • Learns which neighbors are important (anisotropic)
  • Better at handling heterogeneous neighborhoods
  • Respects sharp boundaries (e.g., tumor-stroma)

Weaknesses:

  • More parameters (attention + multi-heads)
  • Slower than GCN
  • Can be unstable (needs careful tuning of dropout, LR)

Architecture Comparison Table

| Model | Neighbor Weights | Sampling | Inductive? | Best Use Case |
| --- | --- | --- | --- | --- |
| GCN | Fixed by normalized adjacency | Full graph (no sampling) | No (transductive) | Smooth tissue, clean spatial graphs, small-to-medium datasets |
| GraphSAGE | Fixed aggregator (mean/max/LSTM) | Neighborhood sampling | Yes | Large graphs, multiple slides, production systems |
| GAT | Learned per-edge attention | Full neighborhood (or can sample) | Often yes | Heterogeneous tissue, sharp boundaries, variable neighbor quality |

Visual Comparison:

GCN:      All neighbors weighted equally (by degree normalization)
          [Spot] ← (0.25, 0.25, 0.25, 0.25) ← [4 neighbors]

GraphSAGE: Sample subset, aggregate with mean/max
          [Spot] ← sample 2 → mean([n1, n2])

GAT:      Learn attention per edge
          [Spot] ← (α₁=0.5, α₂=0.3, α₃=0.15, α₄=0.05) ← [4 neighbors]
                    ↑ Learned weights

Graph Construction Strategies

Critical insight: Graph topology often matters more than the GNN architecture.

1. Spatial Proximity Graphs

k-Nearest Neighbors (kNN)

Connect each spot to its k closest spatial neighbors.

import numpy as np
from sklearn.neighbors import NearestNeighbors
import torch
 
def knn_graph(coords, k=6):
    """
    Build kNN graph from spatial coordinates.
 
    Args:
        coords: [N, 2] array of (x, y) positions
        k: number of neighbors
    Returns:
        edge_index: [2, E] tensor in COO format
    """
    nbrs = NearestNeighbors(n_neighbors=k+1, algorithm='kd_tree').fit(coords)
    distances, indices = nbrs.kneighbors(coords)
 
    # Exclude self (first neighbor)
    src = np.repeat(np.arange(len(coords)), k)
    dst = indices[:, 1:].reshape(-1)
 
    # Make undirected
    edge_index = np.stack([src, dst], axis=0)
    edge_index = np.concatenate([edge_index, edge_index[::-1]], axis=1)
    edge_index = np.unique(edge_index, axis=1)
 
    return torch.as_tensor(edge_index, dtype=torch.long)

Pros: Simple, controllable density
Cons: A fixed k ignores variation in tissue density

Effect of k on smoothing:

  • Small k (4-6): Preserves sharp boundaries, less smoothing
  • Large k (12-20): More smoothing, boundaries blur, risk of over-smoothing

Radius (ε-ball) Graph

Connect all spots within distance ε.

def radius_graph(coords, radius=100.0):
    """
    Build radius graph: connect nodes within distance threshold.
    """
    from sklearn.neighbors import radius_neighbors_graph
    A = radius_neighbors_graph(coords, radius, mode='connectivity')
    edge_index = torch.as_tensor(np.array(A.nonzero()), dtype=torch.long)
    return edge_index

Pros: Density-adaptive, physically meaningful
Cons: Can create disconnected components if ε is too small

Delaunay Triangulation

Build triangles connecting nearest points without overlap—preserves natural geometry.

from scipy.spatial import Delaunay
import numpy as np
 
def delaunay_graph(coords):
    """
    Build graph from Delaunay triangulation.
    Used in SpaGCN for natural spatial connectivity.
    """
    tri = Delaunay(coords)
    edges = set()
    for simplex in tri.simplices:
        for i in range(3):
            edge = tuple(sorted([simplex[i], simplex[(i+1)%3]]))
            edges.add(edge)
 
    edges = np.array(list(edges)).T
    # Make undirected
    edge_index = np.concatenate([edges, edges[::-1]], axis=1)
    return torch.as_tensor(edge_index, dtype=torch.long)

Pros: Natural geometric structure, used in SpaGCN
Cons: Can create long edges in sparse regions


2. Expression Similarity Graphs

Connect spots with similar gene expression, independent of distance.

def expression_similarity_graph(X, k=10, metric='cosine'):
    """
    Build kNN graph in expression space.
 
    Args:
        X: [N, genes] expression matrix
        k: number of neighbors
        metric: 'cosine' or 'euclidean'
    """
    from sklearn.neighbors import NearestNeighbors
    nbrs = NearestNeighbors(n_neighbors=k+1, metric=metric).fit(X)
    distances, indices = nbrs.kneighbors(X)
 
    src = np.repeat(np.arange(len(X)), k)
    dst = indices[:, 1:].reshape(-1)
    edge_index = np.stack([src, dst], axis=0)
 
    # Undirected
    edge_index = np.concatenate([edge_index, edge_index[::-1]], axis=1)
    edge_index = np.unique(edge_index, axis=1)
 
    return torch.as_tensor(edge_index, dtype=torch.long)

Pros: Captures functional/molecular similarity
Cons: May connect distant regions; ignores spatial structure


3. Hybrid (Multi-view) Graphs

Combine spatial and expression information.

Weighted combination:

\mathbf{A} = \alpha \cdot \mathbf{A}_{\text{spatial}} + (1 - \alpha) \cdot \mathbf{A}_{\text{expression}}
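
A minimal sketch of this fixed-weight fusion with dense adjacency matrices (the helper name is illustrative, and α is a hyperparameter to tune):

import numpy as np

def combine_adjacency(A_spatial, A_expression, alpha=0.7):
    """
    Fixed-weight fusion of two [N, N] adjacency matrices (dense sketch).
    alpha=1 keeps only the spatial graph; alpha=0 only the expression graph.
    """
    return alpha * A_spatial + (1 - alpha) * A_expression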

Learnable weighting:

class MultiViewGNN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.gcn_spatial = GCNLayer(in_channels, out_channels)
        self.gcn_expression = GCNLayer(in_channels, out_channels)
        self.alpha = nn.Parameter(torch.tensor(0.5))  # Learnable weight
 
    def forward(self, x, edge_index_spatial, edge_index_expr):
        h_spatial = self.gcn_spatial(x, edge_index_spatial)
        h_expr = self.gcn_expression(x, edge_index_expr)
 
        alpha = torch.sigmoid(self.alpha)  # Ensure [0,1]
        return alpha * h_spatial + (1 - alpha) * h_expr

Used in: STAGATE, conST, GraphST


4. Image-Based Graphs (Histology Integration)

If H&E images are available, use morphological similarity.

def image_weighted_graph(coords, image_features, k=6, alpha=0.5):
    """
    Weighted graph combining spatial distance and image feature similarity.
 
    Args:
        coords: [N, 2] spatial positions
        image_features: [N, D] CNN features from histology patches
        alpha: weight between spatial and image similarity (1=spatial only, 0=image only)
    """
    from sklearn.neighbors import NearestNeighbors
    from sklearn.metrics.pairwise import cosine_similarity
 
    # Spatial kNN
    nbrs_spatial = NearestNeighbors(n_neighbors=k+1).fit(coords)
    dist_spatial, idx_spatial = nbrs_spatial.kneighbors(coords)
 
    # Image similarity
    img_sim = cosine_similarity(image_features)
 
    edges = []
    weights = []
    for i in range(len(coords)):
        for j_idx, j in enumerate(idx_spatial[i, 1:]):  # Skip self
            spatial_sim = 1.0 / (1.0 + dist_spatial[i, j_idx+1])
            image_sim_val = img_sim[i, j]
 
            combined_weight = alpha * spatial_sim + (1 - alpha) * image_sim_val
            edges.append([i, j])
            weights.append(combined_weight)
 
    edge_index = torch.tensor(edges, dtype=torch.long).T
    edge_weight = torch.tensor(weights, dtype=torch.float)
 
    return edge_index, edge_weight

Used in: SpaGCN, conST, MERGE


State-of-the-Art Methods in Spatial Transcriptomics

Method Comparison Table

| Method | Year | Graph Type | GNN Architecture | Key Innovation | Task |
| --- | --- | --- | --- | --- | --- |
| SpaGCN | 2021 | Spatial + histology weighted | GCN | Delaunay + image features, domain detection | Clustering, SVG detection |
| STAGATE | 2022 | Spatial kNN | Graph attention autoencoder | Learned attention, autoencoder framework | Clustering, denoising, trajectory |
| SEDR | 2021 | Spatial kNN | VGAE + deep autoencoder | Dual embedding (expression + spatial VGAE) | Clustering, embedding |
| GraphST | 2023 | Multi-view (spatial + expr) | GCN + contrastive learning | Self-supervised contrastive, multi-sample integration | Clustering, deconvolution |
| STdGCN | 2023 | Hybrid (spot + pseudo-spot) | GCN | Integrates scRNA reference via hybrid graph | Deconvolution |
| stAA | 2024 | Spatial kNN | Adversarial VGAE | WGAN regularization on embeddings | Clustering |
| stMCDI | 2024 | Spatial | GNN encoder in diffusion model | Diffusion-based imputation with GNN | Missing-value imputation |
| MERGE | 2023 | Hierarchical image patches | Multi-scale GNN | Image → gene prediction via hierarchical graph | Prediction from histology |

Deep Dive: SpaGCN

Paper: Hu et al., Nature Methods (2021)

Problem: Identify spatial domains using both gene expression and tissue morphology.

Graph Construction:

  1. Build Delaunay triangulation from spot coordinates
  2. Weight edges by combination of:
    • Spatial distance
    • Histology image similarity (compare H&E patches around each spot)

w_{ij} = \exp\left(-\frac{d_{\text{spatial}}(i,j)^2}{2\sigma_d^2}\right) \times \exp\left(-\frac{d_{\text{image}}(i,j)^2}{2\sigma_i^2}\right)
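
As a sketch, these weights can be computed from pairwise distance matrices; the helper name is illustrative and the σ values are dataset-dependent hyperparameters:

import numpy as np

def gaussian_edge_weights(d_spatial, d_image, sigma_d=50.0, sigma_i=1.0):
    """
    SpaGCN-style edge weights: product of spatial and image Gaussian kernels.
    d_spatial, d_image: [N, N] pairwise distance matrices.
    sigma_d, sigma_i: kernel bandwidths (dataset-dependent hyperparameters).
    """
    w_spatial = np.exp(-d_spatial**2 / (2 * sigma_d**2))
    w_image = np.exp(-d_image**2 / (2 * sigma_i**2))
    return w_spatial * w_image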

GNN: Standard GCN with weighted adjacency

Loss: Combines:

  • Clustering loss (Louvain or k-means on embeddings)
  • Spatial regularization (neighbors should have similar cluster assignments)

Code Sketch:

# Weighted adjacency from spatial + image
A_weighted = spatial_weight * image_weight
 
# GCN embedding
H = gcn_layers(X, A_weighted)
 
# Clustering
clusters = louvain_clustering(H)
 
# Find spatially variable genes per domain
svg_genes = differential_expression_per_domain(clusters)

Strengths: Multi-modal, interpretable, works well on Visium data
Limitations: Static graph, sensitive to weighting hyperparameters


Deep Dive: STAGATE

Paper: Dong & Chen, Nature Communications (2022)

Problem: Learn representations robust to noise, preserve spatial structure, enable downstream tasks (clustering, pseudotime).

Architecture: Graph Attention Autoencoder

  • Encoder: Multi-layer GAT → latent embedding Z
  • Decoder: Reconstruct both expression X and adjacency A

Key Innovation: Attention mechanism learns which neighbors are relevant, mitigating over-smoothing.

Loss:

\mathcal{L} = \mathcal{L}_{\text{recon}}(\mathbf{X}, \hat{\mathbf{X}}) + \lambda \cdot \mathcal{L}_{\text{graph}}(\mathbf{A}, \hat{\mathbf{A}})

Where:

  • \mathcal{L}_{\text{recon}} = MSE for expression reconstruction
  • \mathcal{L}_{\text{graph}} = binary cross-entropy for adjacency reconstruction

Code Outline:

class STAGATE(nn.Module):
    def __init__(self, n_genes, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = nn.ModuleList([
            GATLayer(n_genes, hidden_dim),
            GATLayer(hidden_dim, latent_dim)
        ])
        self.decoder_expr = nn.Linear(latent_dim, n_genes)
        # Graph decoder: inner product of embeddings (computed in forward)
 
    def forward(self, x, edge_index):
        # Encode
        h = x
        for layer in self.encoder:
            h = F.elu(layer(h, edge_index))
 
        # Decode
        x_recon = self.decoder_expr(h)
        adj_recon = torch.sigmoid(h @ h.T)  # Inner product for adjacency
 
        return x_recon, adj_recon, h

Strengths: Adaptive smoothing, rich embeddings for multiple tasks
Limitations: More hyperparameters, potential overfitting on small datasets


Deep Dive: GraphST (Contrastive Learning)

Paper: Long et al., Nature Communications (2023)

Problem: Integrate multiple tissue sections, perform clustering and deconvolution.

Key Idea: Self-supervised contrastive learning on graph embeddings.

Method:

  1. Build two views of the same graph (e.g., different augmentations or spatial vs. expression graphs)
  2. Encode both with GCN
  3. Contrastive loss: pull together embeddings from same spot, push apart different spots

\mathcal{L}_{\text{contrast}} = -\log\frac{\exp(\text{sim}(z_i, z_i^+) / \tau)}{\sum_j \exp(\text{sim}(z_i, z_j) / \tau)}

Where z_i^+ is the augmented view of spot i.

Benefits:

  • Learns robust representations without labeled data
  • Aligns multiple slides in shared embedding space

Code Sketch:

# Create two views
edge_index_view1 = knn_graph(coords, k=6)
edge_index_view2 = expression_similarity_graph(X, k=10)
 
# Encode both
z1 = gnn_encoder(X, edge_index_view1)
z2 = gnn_encoder(X, edge_index_view2)
 
# Contrastive loss (InfoNCE)
loss = contrastive_loss(z1, z2, temperature=0.5)
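
The contrastive_loss above is left abstract; a minimal InfoNCE sketch, assuming matching spots across the two views are positives and all other spots serve as in-batch negatives, could look like:

import torch
import torch.nn.functional as F

def contrastive_loss(z1, z2, temperature=0.5):
    """
    Minimal InfoNCE sketch: embeddings of the same spot in the two views are
    positives; every other spot in the batch acts as a negative.
    """
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.T / temperature                      # [N, N] similarities
    targets = torch.arange(z1.size(0), device=z1.device)  # positives on the diagonal
    # Symmetrize: view1 -> view2 and view2 -> view1
    return 0.5 * (F.cross_entropy(logits, targets) +
                  F.cross_entropy(logits.T, targets))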

Implementation Guide

Complete Pipeline: Spatial Domain Detection

Here's a complete, runnable pipeline from raw Visium data to domain detection using a GNN.

Step 1: Load Data

import scanpy as sc
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
 
# Load Visium data (standard scanpy format)
adata = sc.read_visium("path/to/visium/data")

# Extract coordinates
coords = adata.obsm['spatial']  # [N, 2]

# Normalize counts, log-transform, then select highly variable genes
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=2000)
adata = adata[:, adata.var['highly_variable']].copy()

# Dense expression matrix, standardized per gene
X = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X  # [N, genes]
X = StandardScaler().fit_transform(X)
 
# To tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
coords_np = np.array(coords)

Step 2: Build Graph

from sklearn.neighbors import NearestNeighbors
 
def build_spatial_graph(coords, k=6):
    nbrs = NearestNeighbors(n_neighbors=k+1).fit(coords)
    _, indices = nbrs.kneighbors(coords)
 
    src = np.repeat(np.arange(len(coords)), k)
    dst = indices[:, 1:].reshape(-1)
    edge_index = np.stack([src, dst], axis=0)
 
    # Undirected
    edge_index = np.concatenate([edge_index, edge_index[::-1]], axis=1)
    edge_index = np.unique(edge_index, axis=1)
 
    return torch.tensor(edge_index, dtype=torch.long)
 
edge_index = build_spatial_graph(coords_np, k=6)

Step 3: Define Model

class SpatialGNN(nn.Module):
    def __init__(self, n_genes, hidden_dim=128, latent_dim=32, n_clusters=7):
        super().__init__()
        self.encoder = nn.ModuleList([
            GCNLayer(n_genes, hidden_dim),
            GCNLayer(hidden_dim, latent_dim)
        ])
        self.cluster_head = nn.Linear(latent_dim, n_clusters)
 
    def forward(self, x, edge_index):
        h = x
        for layer in self.encoder:
            h = layer(h, edge_index)
            h = F.relu(h)
            h = F.dropout(h, p=0.3, training=self.training)
 
        logits = self.cluster_head(h)
        return h, logits
 
model = SpatialGNN(n_genes=X.shape[1], n_clusters=7)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Step 4: Training Loop

from sklearn.cluster import KMeans
 
def train_epoch(model, x, edge_index, optimizer):
    model.train()
    optimizer.zero_grad()
 
    # Forward
    embeddings, logits = model(x, edge_index)
 
    # Pseudo-labels from k-means (self-supervised). Cluster IDs can permute
    # between epochs; in practice, refresh pseudo-labels every few epochs or
    # match them across epochs (e.g., Hungarian matching).
    with torch.no_grad():
        kmeans = KMeans(n_clusters=7, n_init=10, random_state=42)
        pseudo_labels = kmeans.fit_predict(embeddings.detach().cpu().numpy())
        pseudo_labels = torch.tensor(pseudo_labels, dtype=torch.long)
 
    # Cross-entropy loss
    loss = F.cross_entropy(logits, pseudo_labels)
 
    loss.backward()
    optimizer.step()
 
    return loss.item()
 
# Train
for epoch in range(100):
    loss = train_epoch(model, X_tensor, edge_index, optimizer)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss:.4f}")

Step 5: Extract Clusters and Visualize

model.eval()
with torch.no_grad():
    embeddings, logits = model(X_tensor, edge_index)
    clusters = logits.argmax(dim=1).cpu().numpy()
 
# Add to adata (as strings so scanpy treats clusters as categorical)
adata.obs['gnn_cluster'] = clusters.astype(str)
 
# Visualize on tissue
sc.pl.spatial(adata, color='gnn_cluster', spot_size=150)

Practical Considerations

1. Over-smoothing Problem

Issue: With many GNN layers, all node features converge to the same value.

Why: Each layer mixes neighbors' features; with enough layers, information diffuses across the whole graph and washes out local differences.

Solutions:

  • Limit layers: 2-3 layers often enough for ST (tissue neighborhoods are local)

  • Skip connections: Add residual connections (see the sketch after this list)

    h_new = gcn_layer(h) + h  # Residual
  • Regularization (DropEdge, PairNorm): Randomly drop edges or normalize layer outputs
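
As a sketch, a residual wrapper around the GCNLayer defined earlier (the class name is illustrative; a linear projection handles dimension changes on the skip path):

class ResidualGCNBlock(nn.Module):
    """
    GCN layer with a skip connection to curb over-smoothing (sketch).
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.gcn = GCNLayer(in_channels, out_channels)
        # Project the skip path only when dimensions change
        self.proj = (nn.Identity() if in_channels == out_channels
                     else nn.Linear(in_channels, out_channels))

    def forward(self, x, edge_index):
        return torch.relu(self.gcn(x, edge_index) + self.proj(x))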

Rule of thumb for ST:

  • Small tissue sections: 2 layers
  • Whole-slide: 3-4 layers max
  • Multi-sample integration: Consider global pooling instead of deeper GNN

2. Choosing Hyperparameter k (Number of Neighbors)

Trade-off:

  • Low k (4-6): Sharp boundaries, less smoothing, more noise sensitivity
  • High k (12-20): Smoother, more stable, but blurs boundaries

Guideline:

  • Delaunay: Natural choice, typically k ≈ 6
  • Brain cortex / layered tissue: Lower k to preserve layers
  • Tumor heterogeneity: Mid-range k (6-8)
  • Dense spot arrays (Visium HD): Higher k (10-15)

Experiment: Try k ∈ {4, 6, 8, 10, 12} and check (a sweep sketch follows this list):

  • Silhouette score on clusters
  • Agreement with known anatomical regions
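
A minimal sweep sketch; train_and_embed is a hypothetical helper that wraps the training loop from the Implementation Guide and returns embeddings and cluster labels as NumPy arrays:

from sklearn.metrics import silhouette_score

for k in [4, 6, 8, 10, 12]:
    edge_index = build_spatial_graph(coords_np, k=k)
    # train_and_embed is hypothetical: train on this graph, return
    # (embeddings, cluster_labels) as NumPy arrays
    embeddings, clusters = train_and_embed(X_tensor, edge_index)
    score = silhouette_score(embeddings, clusters)
    print(f"k={k}: silhouette={score:.3f}")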

3. Normalization and Preprocessing

Critical steps:

  1. Gene selection: Use highly variable genes (1000-3000)
  2. Expression normalization: Log1p + standard scaling
  3. Graph normalization: Symmetric \tilde{\mathbf{D}}^{-1/2} \tilde{\mathbf{A}} \tilde{\mathbf{D}}^{-1/2} (built into GCN)
  4. Feature scaling: StandardScaler before training
import scanpy as sc
from sklearn.preprocessing import StandardScaler
 
# Standard pipeline
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=2000)
adata = adata[:, adata.var.highly_variable]
 
X = StandardScaler().fit_transform(adata.X.toarray())

4. Evaluation Metrics

For clustering:

  • Adjusted Rand Index (ARI): Compare with ground-truth annotations
  • Normalized Mutual Information (NMI)
  • Silhouette score: Cluster separation in embedding space

For deconvolution:

  • RMSE vs. known proportions (if synthetic data)
  • Correlation with reference cell-type markers

For imputation:

  • RMSE on held-out genes
  • Pearson correlation with ground truth
from sklearn.metrics import adjusted_rand_score, silhouette_score
 
# Clustering evaluation
ari = adjusted_rand_score(true_labels, predicted_clusters)
silhouette = silhouette_score(embeddings, predicted_clusters)
 
print(f"ARI: {ari:.3f}, Silhouette: {silhouette:.3f}")

5. Common Pitfalls

| Problem | Symptom | Solution |
| --- | --- | --- |
| Over-smoothing | All embeddings identical | Reduce layers, add skip connections |
| Disconnected graph | NaN in training | Check graph connectivity, increase k or ε |
| Unstable GAT | Loss spikes | Lower learning rate, increase dropout on attention |
| Poor clustering | Random assignments | Better graph construction, tune k, check normalization |
| High memory | OOM errors | Use mini-batching (GraphSAGE), neighbor sampling |

Advanced Topics

1. Hypergraph Neural Networks

Standard GNN: Pairwise edges connect exactly 2 nodes
Hypergraph: Hyperedges connect ≥ 2 nodes simultaneously

Use case in ST: Group of spots in a tissue domain share a hyperedge

Example: Hypergraph Neural Networks Reveal Spatial Domains
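
A minimal sketch of hypergraph message passing through an incidence matrix (normalization simplified relative to published formulations; the names are illustrative):

import torch

def hypergraph_conv(x, incidence, weight):
    """
    Simplified hypergraph convolution: nodes -> hyperedges -> nodes.
    x:         [N, F] node features
    incidence: [N, M] binary (float) node-hyperedge membership matrix
    weight:    [F, F_out] learnable projection
    """
    # Average node features within each hyperedge
    edge_deg = incidence.sum(dim=0).clamp(min=1)            # [M]
    edge_feat = (incidence.T @ x) / edge_deg.unsqueeze(1)   # [M, F]
    # Scatter hyperedge features back to their member nodes
    node_deg = incidence.sum(dim=1).clamp(min=1)            # [N]
    out = (incidence @ edge_feat) / node_deg.unsqueeze(1)   # [N, F]
    return out @ weight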

2. Hierarchical and Multi-Scale Graphs

Idea: Build graphs at multiple resolutions

  • Fine-scale: spot-to-spot
  • Coarse-scale: domain-to-domain (via clustering)

Benefits: Capture both local and global patterns

Example: MERGE builds hierarchical image patch graphs

3. Graph Transformers

Replace message-passing with full self-attention over all nodes.

Pros: No over-smoothing, captures long-range dependencies
Cons: O(N²) complexity

Emerging work: Graphormer, GraphGPS


Resources and Next Steps


Practice Projects

  1. Reproduce SpaGCN on a public Visium dataset (e.g., mouse brain)
  2. Compare GCN vs. GAT on your own ST data—measure ARI, runtime, smoothness
  3. Build a hybrid graph (spatial + expression) and tune the weighting parameter α
  4. Implement contrastive learning for multi-sample integration

Conclusion

Graph Neural Networks provide a powerful, principled framework for spatial transcriptomics analysis by:

  • Respecting spatial structure through graph topology
  • Learning adaptive representations via message passing
  • Integrating multi-modal data (expression, coordinates, images)
  • Scaling to large datasets with sampling and mini-batching

The keys to success are:

  1. Graph construction: Choose topology that reflects your biology
  2. Architecture selection: Match model complexity to data size and task
  3. Regularization: Prevent over-smoothing and overfitting
  4. Evaluation: Use domain-relevant metrics and biological validation

As spatial omics technologies advance (Visium HD, Xenium, CosMx), GNNs will remain essential tools for discovering tissue organization, cell-cell interactions, and disease mechanisms.


Happy graph learning!
