Files

152 lines
6.3 KiB
Python

#!/usr/bin/env python3
"""
VGAE encoder architectures for chromatin contact graphs.

Every encoder takes node features ``x``, ``edge_index`` and an optional
``edge_weight`` and returns a ``(mu, logstd)`` pair produced by two GCNConv
output heads.

Exported symbols
----------------
GCNEncoder — original 2-layer GCN (kept for backward compatibility)
GATEncoder — 2-layer GATv2 with multi-head attention, followed by GCNConv mu/logstd heads
DeepGCNEncoder — 3-layer GCN (two hidden GCNConv layers + mu/logstd heads) with
                 BatchNorm after the first hidden layer; no residual connections
Encoder — alias for GCNEncoder (backward compat)
build_encoder() — factory: instantiates an encoder from a string name
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATv2Conv
# ---------------------------------------------------------------------------
# GCN encoder (baseline)
# ---------------------------------------------------------------------------
class GCNEncoder(nn.Module):
    """Baseline two-layer GCN encoder with input BatchNorm.

    Architecture: BatchNorm → GCNConv(hidden) → ReLU → Dropout
                  → GCNConv_mu / GCNConv_logstd

    ``edge_weight`` (if given) is forwarded to every GCNConv. Extra keyword
    arguments (``**_``) are accepted and ignored so the class can be built
    through build_encoder with a shared kwargs set (e.g. ``heads``).
    """

    def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2, **_):
        super().__init__()

        def _gcn(n_in: int, n_out: int) -> GCNConv:
            # All convolutions here add self-loops and use GCNConv's built-in normalisation.
            return GCNConv(n_in, n_out, add_self_loops=True, normalize=True)

        # Creation order (norm, gc1, gc_mu, gc_log) matches the original so any
        # random parameter initialisation happens in the same sequence.
        self.norm = nn.BatchNorm1d(in_dim)
        self.gc1 = _gcn(in_dim, hidden)
        self.gc_mu = _gcn(hidden, latent)
        self.gc_log = _gcn(hidden, latent)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        """Return ``(mu, logstd)`` node embeddings."""
        hidden_repr = self.gc1(self.norm(x), edge_index, edge_weight)
        hidden_repr = F.dropout(F.relu(hidden_repr), p=self.dropout, training=self.training)
        mu = self.gc_mu(hidden_repr, edge_index, edge_weight)
        logstd = self.gc_log(hidden_repr, edge_index, edge_weight)
        return mu, logstd
# ---------------------------------------------------------------------------
# GAT encoder (preferred for Hi-C: handles degree heterogeneity via attention)
# ---------------------------------------------------------------------------
class GATEncoder(nn.Module):
    """Two-layer GATv2 encoder with GCNConv output heads.

    Each GATv2 layer applies multi-head attention, so neighbours are combined
    with learned attention scores instead of GCNConv's fixed normalised
    weighting. By default those scores are computed from node features only:
    contact strengths (``edge_weight``) reach the GATv2 layers only when
    ``use_edge_weight=True``. In that case they are passed to GATv2Conv as a
    one-dimensional edge feature. ``edge_weight`` is always used by the
    mu/logstd GCNConv heads.

    Architecture: BatchNorm → GATv2(hidden, heads) → ELU → BN → Dropout
                  → GATv2(hidden, heads) → ELU → Dropout
                  → GCNConv_mu / GCNConv_logstd

    Parameters
    ----------
    in_dim, hidden, latent : int
        Input feature size, total hidden width (concatenated over heads) and
        latent size.
    heads : int
        Number of attention heads; must be >= 1 and divide ``hidden``.
    dropout : float
        Passed to GATv2Conv (attention-coefficient dropout) and also used as
        the feature dropout between layers.
    use_edge_weight : bool, keyword-only
        If True, both GATv2 layers are built with ``edge_dim=1`` and receive
        ``edge_weight`` as their edge attribute. Defaults to False, which
        reproduces the original parameter set and behaviour.
    **_ :
        Ignored; lets build_encoder forward a shared set of kwargs.

    Raises
    ------
    ValueError
        If ``heads < 1`` or ``hidden`` is not divisible by ``heads``.
    """

    def __init__(self, in_dim: int, hidden: int, latent: int,
                 heads: int = 4, dropout: float = 0.2, *,
                 use_edge_weight: bool = False, **_):
        super().__init__()
        # Checked first so heads=0 gives a clear error rather than ZeroDivisionError.
        if heads < 1:
            raise ValueError(f"heads must be >= 1, got {heads}")
        if hidden % heads != 0:
            raise ValueError(f"hidden ({hidden}) must be divisible by heads ({heads})")
        self.use_edge_weight = use_edge_weight
        # None (GATv2Conv's default) keeps the original layers when disabled.
        edge_dim = 1 if use_edge_weight else None
        self.norm = nn.BatchNorm1d(in_dim)
        # concat=True → output width is heads * (hidden // heads) == hidden.
        self.gat1 = GATv2Conv(in_dim, hidden // heads, heads=heads,
                              dropout=dropout, add_self_loops=True, concat=True,
                              edge_dim=edge_dim)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.gat2 = GATv2Conv(hidden, hidden // heads, heads=heads,
                              dropout=dropout, add_self_loops=True, concat=True,
                              edge_dim=edge_dim)
        self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
        self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        """Return ``(mu, logstd)`` node embeddings."""
        # getattr: whole-module pickles created before this option existed lack the attribute.
        edge_attr = None
        if getattr(self, "use_edge_weight", False) and edge_weight is not None:
            # GATv2Conv expects edge features shaped (num_edges, edge_dim).
            edge_attr = edge_weight.reshape(-1, 1)

        x = self.norm(x)
        h = F.elu(self.gat1(x, edge_index, edge_attr=edge_attr))
        h = self.bn1(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = F.elu(self.gat2(h, edge_index, edge_attr=edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        # The GCNConv heads always use edge_weight, independent of use_edge_weight.
        return self.gc_mu(h, edge_index, edge_weight), self.gc_log(h, edge_index, edge_weight)
# ---------------------------------------------------------------------------
# Deep GCN encoder (3 message-passing layers)
# ---------------------------------------------------------------------------
class DeepGCNEncoder(nn.Module):
    """Three-layer GCN encoder with a wider receptive field than GCNEncoder.

    Node features pass through two hidden GCNConv layers before the mu /
    logstd heads. Each output therefore aggregates information from up to
    three hops away, one more than GCNEncoder. Only the first hidden layer
    is followed by BatchNorm, and there are no residual connections.

    Architecture: BatchNorm → GCN1 → ReLU → BN → Dropout
                  → GCN2 → ReLU → Dropout
                  → GCNConv_mu / GCNConv_logstd

    Extra keyword arguments (``**_``) are accepted and ignored so the class
    can be built through build_encoder with a shared kwargs set.
    """

    def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2, **_):
        super().__init__()
        conv_opts = {"add_self_loops": True, "normalize": True}
        # Creation order (norm, gc1, bn1, gc2, gc_mu, gc_log) matches the
        # original so any random parameter initialisation happens in the same sequence.
        self.norm = nn.BatchNorm1d(in_dim)
        self.gc1 = GCNConv(in_dim, hidden, **conv_opts)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.gc2 = GCNConv(hidden, hidden, **conv_opts)
        self.gc_mu = GCNConv(hidden, latent, **conv_opts)
        self.gc_log = GCNConv(hidden, latent, **conv_opts)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_weight=None):
        """Return ``(mu, logstd)``; ``edge_weight`` is passed to every GCNConv."""
        h = self.norm(x)
        # One (convolution, optional post-activation norm) pair per hidden layer.
        for conv, post_norm in ((self.gc1, self.bn1), (self.gc2, None)):
            h = F.relu(conv(h, edge_index, edge_weight))
            if post_norm is not None:
                h = post_norm(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
        mu = self.gc_mu(h, edge_index, edge_weight)
        logstd = self.gc_log(h, edge_index, edge_weight)
        return mu, logstd
# ---------------------------------------------------------------------------
# Backward compatibility alias
# ---------------------------------------------------------------------------
# Kept so existing imports of ``Encoder`` keep working; it is the very same
# class object as GCNEncoder (not a subclass).
Encoder = GCNEncoder
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
# Registry consulted by build_encoder(): name -> encoder class.
# Keys must be lower-case because build_encoder() lower-cases its argument
# before the lookup.
_ENCODERS = {
    "gcn": GCNEncoder,
    "gat": GATEncoder,
    "deep_gcn": DeepGCNEncoder,
}
def build_encoder(name: str, in_dim: int, hidden: int, latent: int, **kwargs) -> nn.Module:
    """Create an encoder instance from its registry name.

    Parameters
    ----------
    name : {"gcn", "gat", "deep_gcn"}
        Encoder identifier, matched case-insensitively.
    in_dim, hidden, latent : int
        Layer dimensions forwarded to the encoder constructor.
    **kwargs
        Extra constructor arguments (e.g. ``dropout=0.3``, ``heads=8``).

    Returns
    -------
    nn.Module
        A newly constructed encoder.

    Raises
    ------
    ValueError
        If ``name`` (after lower-casing) is not a registered encoder.
    """
    key = name.lower()
    try:
        encoder_cls = _ENCODERS[key]
    except KeyError:
        # Suppress the KeyError context: the ValueError alone describes the problem.
        raise ValueError(f"Unknown encoder '{key}'. Choose from {list(_ENCODERS)}") from None
    return encoder_cls(in_dim=in_dim, hidden=hidden, latent=latent, **kwargs)