152 lines
6.3 KiB
Python
152 lines
6.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
VGAE encoder architectures for chromatin contact graphs.
|
|
|
|
Exported symbols
|
|
----------------
|
|
GCNEncoder — original 2-layer GCN (kept for backward compatibility)
|
|
GATEncoder — 2-layer GATv2 with multi-head attention
|
|
DeepGCNEncoder — 3-layer GCN with BatchNorm after the first layer (no residual connections)
|
|
Encoder — alias for GCNEncoder (backward compat)
|
|
build_encoder() — factory: returns the right class from a string name
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch_geometric.nn import GCNConv, GATv2Conv
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GCN encoder (baseline)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class GCNEncoder(nn.Module):
    """Baseline encoder: one shared GCN layer feeding two GCN output heads.

    Pipeline::

        BatchNorm1d(in_dim) → GCNConv(in_dim → hidden) → ReLU → Dropout
                            → gc_mu(hidden → latent) / gc_log(hidden → latent)

    Parameters
    ----------
    in_dim : width of the input node features
    hidden : width of the shared hidden layer
    latent : width of each output head
    dropout : dropout probability applied to the hidden activations
    **_ : any further keyword arguments are accepted and ignored
    """

    def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2, **_):
        super().__init__()
        # Every graph convolution in this encoder shares the same options.
        conv_opts = {"add_self_loops": True, "normalize": True}
        self.norm = nn.BatchNorm1d(in_dim)
        self.gc1 = GCNConv(in_dim, hidden, **conv_opts)
        self.gc_mu = GCNConv(hidden, latent, **conv_opts)
        self.gc_log = GCNConv(hidden, latent, **conv_opts)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        """Return the ``(gc_mu, gc_log)`` head outputs for the given graph."""
        graph = (edge_index, edge_weight)
        hidden_act = self.gc1(self.norm(x), *graph)
        hidden_act = F.dropout(
            F.relu(hidden_act), p=self.dropout, training=self.training
        )
        mu = self.gc_mu(hidden_act, *graph)
        log_out = self.gc_log(hidden_act, *graph)
        return mu, log_out
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GAT encoder (preferred for Hi-C: handles degree heterogeneity via attention)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class GATEncoder(nn.Module):
    """Two-layer GATv2 encoder.

    Each GATv2 layer applies multi-head attention, which lets the model
    up-weight high-frequency contacts at TAD boundaries and CTCF anchors
    rather than averaging all neighbours uniformly (as GCN does).

    Architecture: BatchNorm → GATv2(hidden, heads) → ELU → BN → Dropout
                            → GATv2(hidden, heads) → ELU → Dropout
                            → GCNConv_mu / GCNConv_logstd

    Both GATv2 layers are built with ``hidden // heads`` channels per head
    and ``concat=True``, which is why ``hidden`` must be a multiple of
    ``heads``.

    Parameters
    ----------
    in_dim : width of the input node features
    hidden : total width of each attention layer's output
    latent : width of each output head
    heads : number of attention heads per GATv2 layer (must be >= 1)
    dropout : probability used both inside the GATv2 layers and for the
        feature dropout applied after each of them
    **_ : any further keyword arguments are accepted and ignored

    Raises
    ------
    ValueError
        If ``heads`` is smaller than 1, or ``hidden`` is not divisible by
        ``heads``.
    """

    def __init__(self, in_dim: int, hidden: int, latent: int,
                 heads: int = 4, dropout: float = 0.2, **_):
        super().__init__()
        # Reject non-positive head counts up front: heads == 0 would otherwise
        # surface as a ZeroDivisionError in the modulo below, and a negative
        # count can slip through the divisibility test and reach GATv2Conv
        # with a negative per-head width.
        if heads < 1:
            raise ValueError(f"heads ({heads}) must be a positive integer")
        if hidden % heads != 0:
            raise ValueError(f"hidden ({hidden}) must be divisible by heads ({heads})")
        self.norm = nn.BatchNorm1d(in_dim)
        self.gat1 = GATv2Conv(in_dim, hidden // heads, heads=heads,
                              dropout=dropout, add_self_loops=True, concat=True)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.gat2 = GATv2Conv(hidden, hidden // heads, heads=heads,
                              dropout=dropout, add_self_loops=True, concat=True)
        self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
        self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        """Return the ``(gc_mu, gc_log)`` head outputs for the given graph.

        ``edge_weight`` is not passed to the two GATv2 layers; it is only
        forwarded to the GCNConv output heads.
        """
        x = self.norm(x)
        h = F.elu(self.gat1(x, edge_index))
        h = self.bn1(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = F.elu(self.gat2(h, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)
        # GATv2 learns its own attention weights, so edge_weight only enters
        # through the final GCNConv heads (mu/log), which accept it.
        return self.gc_mu(h, edge_index, edge_weight), self.gc_log(h, edge_index, edge_weight)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Deep GCN encoder (3 message-passing layers)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class DeepGCNEncoder(nn.Module):
    """Three-layer GCN encoder — wider receptive field than the baseline.

    Pipeline, in the order the layers are applied::

        BatchNorm1d(in_dim) → gc1 → ReLU → BatchNorm1d(hidden) → Dropout
                            → gc2 → ReLU → Dropout
                            → gc_mu(hidden → latent) / gc_log(hidden → latent)

    Parameters
    ----------
    in_dim : width of the input node features
    hidden : width of both hidden layers
    latent : width of each output head
    dropout : dropout probability applied after each hidden layer
    **_ : any further keyword arguments are accepted and ignored
    """

    def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2, **_):
        super().__init__()
        # Every graph convolution in this encoder shares the same options.
        conv_opts = {"add_self_loops": True, "normalize": True}
        self.norm = nn.BatchNorm1d(in_dim)
        self.gc1 = GCNConv(in_dim, hidden, **conv_opts)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.gc2 = GCNConv(hidden, hidden, **conv_opts)
        self.gc_mu = GCNConv(hidden, latent, **conv_opts)
        self.gc_log = GCNConv(hidden, latent, **conv_opts)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        """Return the ``(gc_mu, gc_log)`` head outputs for the given graph."""
        graph = (edge_index, edge_weight)

        # First hidden layer: convolution, ReLU, then batch normalisation.
        feats = F.relu(self.gc1(self.norm(x), *graph))
        feats = F.dropout(self.bn1(feats), p=self.dropout, training=self.training)

        # Second hidden layer: convolution and ReLU only.
        feats = F.relu(self.gc2(feats, *graph))
        feats = F.dropout(feats, p=self.dropout, training=self.training)

        mu = self.gc_mu(feats, *graph)
        log_out = self.gc_log(feats, *graph)
        return mu, log_out
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Backward compatibility alias
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Legacy name: `Encoder` is bound to the very same class object as
# `GCNEncoder`, so code written against the old name keeps working.
Encoder = GCNEncoder
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Factory
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Registry of selectable encoders, keyed by the lowercase name that
# build_encoder() accepts.
_ENCODERS = {
    "gcn": GCNEncoder,
    "gat": GATEncoder,
    "deep_gcn": DeepGCNEncoder,
}


def build_encoder(name: str, in_dim: int, hidden: int, latent: int, **kwargs) -> nn.Module:
    """Look up an encoder class by name and construct it.

    Parameters
    ----------
    name : {"gcn", "gat", "deep_gcn"}
        Matched case-insensitively (the name is lower-cased before lookup).
    in_dim, hidden, latent : layer dimensions, forwarded as keyword arguments
    **kwargs : forwarded unchanged to the constructor (e.g. dropout=0.3, heads=8)

    Raises
    ------
    ValueError
        If the lower-cased name is not a registered encoder.
    """
    key = name.lower()
    encoder_cls = _ENCODERS.get(key)
    if encoder_cls is None:
        raise ValueError(f"Unknown encoder '{key}'. Choose from {list(_ENCODERS)}")
    return encoder_cls(in_dim=in_dim, hidden=hidden, latent=latent, **kwargs)
|