#!/usr/bin/env python3
"""
VGAE encoder architectures for chromatin contact graphs.

Exported symbols
----------------
GCNEncoder      — original 2-layer GCN (kept for backward compatibility)
GATEncoder      — 2-layer GATv2 with multi-head attention
DeepGCNEncoder  — 3-layer GCN with BatchNorm after the first hidden layer
Encoder         — alias for GCNEncoder (backward compat)
build_encoder() — factory: returns the right class from a string name

Every encoder's ``forward(x, edge_index, edge_weight=None)`` returns a
``(mu, logstd)`` pair produced by two parallel GCNConv heads.
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATv2Conv


# ---------------------------------------------------------------------------
# GCN encoder (baseline)
# ---------------------------------------------------------------------------

class GCNEncoder(nn.Module):
    """Two-layer GCN encoder with input BatchNorm.

    Architecture:
        BatchNorm → GCNConv(hidden) → ReLU → Dropout
                  → GCNConv_mu / GCNConv_logstd

    Parameters
    ----------
    in_dim : number of input node features.
    hidden : width of the hidden GCN layer.
    latent : width of each returned tensor (mu and logstd).
    dropout : dropout probability on the hidden representation
        (only active in training mode).
    **_ : extra keyword arguments are accepted and ignored, so that
        ``build_encoder`` can forward encoder-specific options
        (e.g. ``heads``) uniformly to every encoder class.
    """

    def __init__(self, in_dim: int, hidden: int, latent: int,
                 dropout: float = 0.2, **_):
        super().__init__()
        self.norm = nn.BatchNorm1d(in_dim)
        self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
        self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
        self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_weight: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode node features into ``(mu, logstd)``.

        ``edge_weight`` (optional) is forwarded to every GCNConv layer.
        """
        x = self.norm(x)
        h = F.relu(self.gc1(x, edge_index, edge_weight))
        h = F.dropout(h, p=self.dropout, training=self.training)
        return (self.gc_mu(h, edge_index, edge_weight),
                self.gc_log(h, edge_index, edge_weight))


# ---------------------------------------------------------------------------
# GAT encoder (attention-weighted neighbour aggregation)
# ---------------------------------------------------------------------------

class GATEncoder(nn.Module):
    """Two-layer GATv2 encoder.

    Each GATv2 layer aggregates neighbours with learned multi-head attention
    scores. GCN instead uses fixed degree-normalised weights.

    By default the attention scores are computed from node features only, and
    ``edge_weight`` is not seen by the GATv2 layers. Set
    ``use_edge_weight=True`` to also feed the scalar per-edge weight into the
    attention computation through GATv2's ``edge_dim`` mechanism.

    Architecture:
        BatchNorm → GATv2(hidden, heads) → ELU → BN → Dropout
                  → GATv2(hidden, heads) → ELU → Dropout
                  → GCNConv_mu / GCNConv_logstd

    Parameters
    ----------
    in_dim : number of input node features.
    hidden : total width of each GATv2 layer; must be divisible by ``heads``.
    latent : width of each returned tensor (mu and logstd).
    heads : number of attention heads (positive integer).
    dropout : used both as GATv2 attention dropout and as feature dropout.
    use_edge_weight : if True, pass ``edge_weight`` to the GATv2 layers as a
        1-dimensional edge feature. Defaults to False, which reproduces the
        original behaviour and parameter set.
    **_ : extra keyword arguments are accepted and ignored.

    Raises
    ------
    ValueError
        If ``heads`` < 1 or ``hidden`` is not divisible by ``heads``.
    """

    def __init__(self, in_dim: int, hidden: int, latent: int, heads: int = 4,
                 dropout: float = 0.2, *, use_edge_weight: bool = False, **_):
        super().__init__()
        if heads < 1:
            raise ValueError(f"heads must be a positive integer, got {heads}")
        if hidden % heads != 0:
            raise ValueError(f"hidden ({hidden}) must be divisible by heads ({heads})")
        # Each head emits hidden // heads channels; concat=True joins the heads
        # so every GATv2 layer outputs exactly `hidden` features.
        head_dim = hidden // heads
        # edge_dim=None (the GATv2Conv default) keeps feature-only attention
        # and no extra edge-projection parameters.
        edge_dim = 1 if use_edge_weight else None
        self.use_edge_weight = use_edge_weight
        self.norm = nn.BatchNorm1d(in_dim)
        self.gat1 = GATv2Conv(in_dim, head_dim, heads=heads, dropout=dropout,
                              add_self_loops=True, concat=True, edge_dim=edge_dim)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.gat2 = GATv2Conv(hidden, head_dim, heads=heads, dropout=dropout,
                              add_self_loops=True, concat=True, edge_dim=edge_dim)
        self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
        self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
        self.dropout = dropout

    def _attention_edge_attr(self, edge_weight: Optional[torch.Tensor]
                             ) -> Optional[torch.Tensor]:
        """Return ``edge_weight`` reshaped to (E, 1) for GATv2, or None if unused."""
        # getattr keeps whole-module pickles from before this option loadable.
        if edge_weight is None or not getattr(self, "use_edge_weight", False):
            return None
        return edge_weight.reshape(-1, 1)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_weight: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode node features into ``(mu, logstd)``.

        ``edge_weight`` always reaches the final GCNConv mu/logstd layers.
        It reaches the GATv2 attention layers only when the encoder was built
        with ``use_edge_weight=True``.
        """
        edge_attr = self._attention_edge_attr(edge_weight)
        x = self.norm(x)
        h = F.elu(self.gat1(x, edge_index, edge_attr=edge_attr))
        h = self.bn1(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = F.elu(self.gat2(h, edge_index, edge_attr=edge_attr))
        h = F.dropout(h, p=self.dropout, training=self.training)
        # The mu/logstd heads are GCNConv message-passing layers, which accept
        # the scalar edge_weight directly.
        return (self.gc_mu(h, edge_index, edge_weight),
                self.gc_log(h, edge_index, edge_weight))


# ---------------------------------------------------------------------------
# Deep GCN encoder (3 message-passing layers)
# ---------------------------------------------------------------------------

class DeepGCNEncoder(nn.Module):
    """Three-layer GCN encoder with a wider receptive field than the baseline.

    Node features pass through three rounds of message passing (gc1, gc2,
    then gc_mu / gc_log), versus two in ``GCNEncoder``. There is no residual
    connection; BatchNorm is applied after the first hidden layer only.

    Architecture:
        BatchNorm → GCN1 → ReLU → BN → Dropout
                  → GCN2 → ReLU → Dropout
                  → GCNConv_mu / GCNConv_logstd

    Parameters
    ----------
    in_dim : number of input node features.
    hidden : width of both hidden GCN layers.
    latent : width of each returned tensor (mu and logstd).
    dropout : dropout probability after each hidden layer (training only).
    **_ : extra keyword arguments are accepted and ignored.
    """

    def __init__(self, in_dim: int, hidden: int, latent: int,
                 dropout: float = 0.2, **_):
        super().__init__()
        self.norm = nn.BatchNorm1d(in_dim)
        self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.gc2 = GCNConv(hidden, hidden, add_self_loops=True, normalize=True)
        self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
        self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_weight: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode node features into ``(mu, logstd)``.

        ``edge_weight`` (optional) is forwarded to every GCNConv layer.
        """
        x = self.norm(x)
        h = F.relu(self.gc1(x, edge_index, edge_weight))
        h = self.bn1(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = F.relu(self.gc2(h, edge_index, edge_weight))
        h = F.dropout(h, p=self.dropout, training=self.training)
        return (self.gc_mu(h, edge_index, edge_weight),
                self.gc_log(h, edge_index, edge_weight))


# ---------------------------------------------------------------------------
# Backward compatibility alias
# ---------------------------------------------------------------------------

Encoder = GCNEncoder


# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------

# Registry mapping the lower-case names accepted by build_encoder to classes.
_ENCODERS = {
    "gcn": GCNEncoder,
    "gat": GATEncoder,
    "deep_gcn": DeepGCNEncoder,
}


def build_encoder(name: str, in_dim: int, hidden: int, latent: int,
                  **kwargs) -> nn.Module:
    """Instantiate an encoder by name.

    Parameters
    ----------
    name : {"gcn", "gat", "deep_gcn"}
        Case-insensitive; surrounding whitespace is ignored.
    in_dim, hidden, latent : layer dimensions
    **kwargs : passed to the constructor (e.g. dropout=0.3, heads=8,
        use_edge_weight=True for "gat"). Options a given encoder does not use
        are silently ignored by that encoder.

    Raises
    ------
    ValueError
        If ``name`` is not a registered encoder.
    """
    name = name.strip().lower()
    if name not in _ENCODERS:
        raise ValueError(f"Unknown encoder '{name}'. Choose from {list(_ENCODERS)}")
    return _ENCODERS[name](in_dim=in_dim, hidden=hidden, latent=latent, **kwargs)