fix: three silent bugs in training and inference
- Edge weights (ICE log1p) were computed but never passed to GCNConv - encode_graph.py hardcoded GCNEncoder regardless of saved model type - Inference graph lacked to_undirected + edge_weight pipeline from training
This commit is contained in:
143
scripts/model.py
143
scripts/model.py
@@ -1,23 +1,34 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Shared VGAE encoder. Imported by train_vgae.py and encode_graph.py."""
|
||||
"""
|
||||
VGAE encoder architectures for chromatin contact graphs.
|
||||
|
||||
Exported symbols
|
||||
----------------
|
||||
GCNEncoder — original 2-layer GCN (kept for backward compatibility)
|
||||
GATEncoder — 2-layer GATv2 with multi-head attention
|
||||
DeepGCNEncoder — 3-layer GCN with BatchNorm after the first layer (no residual connections)
|
||||
Encoder — alias for GCNEncoder (backward compat)
|
||||
build_encoder() — factory: returns the right class from a string name
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv
|
||||
from torch_geometric.nn import GCNConv, GATv2Conv
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Two-layer GCN encoder for VGAE with input BatchNorm.
|
||||
# ---------------------------------------------------------------------------
|
||||
# GCN encoder (baseline)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Architecture: BatchNorm → GCN(hidden) → ReLU → Dropout → GCN_mu / GCN_logstd
|
||||
class GCNEncoder(nn.Module):
|
||||
"""Two-layer GCN encoder with input BatchNorm.
|
||||
|
||||
The BatchNorm layer normalises raw ChIP-seq signals and its running statistics
|
||||
are saved in model.pt, so encode_graph.py applies identical normalisation to
|
||||
held-out cell lines without a separate scaler file.
|
||||
Architecture: BatchNorm → GCNConv(hidden) → ReLU → Dropout
|
||||
→ GCNConv_mu / GCNConv_logstd
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2):
|
||||
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2, **_):
|
||||
super().__init__()
|
||||
self.norm = nn.BatchNorm1d(in_dim)
|
||||
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
|
||||
@@ -25,8 +36,116 @@ class Encoder(nn.Module):
|
||||
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
def forward(self, x, edge_index, edge_weight=None):
|
||||
x = self.norm(x)
|
||||
h = F.relu(self.gc1(x, edge_index))
|
||||
h = F.relu(self.gc1(x, edge_index, edge_weight))
|
||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
|
||||
return self.gc_mu(h, edge_index, edge_weight), self.gc_log(h, edge_index, edge_weight)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GAT encoder (preferred for Hi-C: handles degree heterogeneity via attention)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class GATEncoder(nn.Module):
    """VGAE encoder built from two GATv2 attention layers.

    Attention lets each node weight its neighbours unequally instead of
    using the fixed degree-normalised average that GCN applies; the intent
    is to emphasise informative contacts (e.g. TAD boundaries, CTCF anchors).

    Architecture: BatchNorm → GATv2(hidden, heads) → ELU → BN → Dropout
                            → GATv2(hidden, heads) → ELU → Dropout
                            → GCNConv_mu / GCNConv_logstd

    Parameters
    ----------
    in_dim : number of input node features
    hidden : total hidden width (split evenly across attention heads)
    latent : dimensionality of the latent mean / log-std outputs
    heads : number of attention heads; must divide ``hidden``
    dropout : dropout probability, used between layers and also forwarded
        to the GATv2 layers' own ``dropout`` argument
    **_ : extra keyword arguments are accepted and ignored

    Raises
    ------
    ValueError
        If ``hidden`` is not a multiple of ``heads``.
    """

    def __init__(self, in_dim: int, hidden: int, latent: int,
                 heads: int = 4, dropout: float = 0.2, **_):
        super().__init__()
        per_head, leftover = divmod(hidden, heads)
        if leftover:
            raise ValueError(f"hidden ({hidden}) must be divisible by heads ({heads})")

        # With concat=True each layer emits per_head * heads == hidden features.
        attn_opts = {"heads": heads, "dropout": dropout,
                     "add_self_loops": True, "concat": True}
        proj_opts = {"add_self_loops": True, "normalize": True}

        # Construction order is kept stable so seeded initialisation and
        # state_dict key order match previously saved models.
        self.norm = nn.BatchNorm1d(in_dim)
        self.gat1 = GATv2Conv(in_dim, per_head, **attn_opts)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.gat2 = GATv2Conv(hidden, per_head, **attn_opts)
        self.gc_mu = GCNConv(hidden, latent, **proj_opts)
        self.gc_log = GCNConv(hidden, latent, **proj_opts)
        self.dropout = dropout

    def _drop(self, h):
        """Apply feature dropout; active only while the module is in training mode."""
        return F.dropout(h, p=self.dropout, training=self.training)

    def forward(self, x, edge_index, edge_weight=None):
        """Return ``(mu, logstd)`` for the given node features and edges.

        ``edge_weight`` is optional. It is consumed only by the two GCNConv
        projection heads; the GATv2 layers are called without it and rely on
        their learned attention coefficients instead.
        """
        h = self.norm(x)
        h = F.elu(self.gat1(h, edge_index))
        h = self._drop(self.bn1(h))
        h = self._drop(F.elu(self.gat2(h, edge_index)))
        mu = self.gc_mu(h, edge_index, edge_weight)
        logstd = self.gc_log(h, edge_index, edge_weight)
        return mu, logstd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deep GCN encoder (3 message-passing layers)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class DeepGCNEncoder(nn.Module):
    """VGAE encoder with three rounds of GCN message passing.

    Two hidden GCN layers followed by the mu / log-std GCN heads give a
    wider receptive field than the two-layer baseline.

    Architecture: BatchNorm → GCN1 → ReLU → BN → Dropout
                            → GCN2 → ReLU → Dropout
                            → GCNConv_mu / GCNConv_logstd

    Parameters
    ----------
    in_dim : number of input node features
    hidden : width of both hidden GCN layers
    latent : dimensionality of the latent mean / log-std outputs
    dropout : dropout probability applied after each hidden layer
    **_ : extra keyword arguments are accepted and ignored
    """

    def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2, **_):
        super().__init__()
        conv_opts = {"add_self_loops": True, "normalize": True}

        # Construction order is kept stable so seeded initialisation and
        # state_dict key order match previously saved models.
        self.norm = nn.BatchNorm1d(in_dim)
        self.gc1 = GCNConv(in_dim, hidden, **conv_opts)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.gc2 = GCNConv(hidden, hidden, **conv_opts)
        self.gc_mu = GCNConv(hidden, latent, **conv_opts)
        self.gc_log = GCNConv(hidden, latent, **conv_opts)
        self.dropout = dropout

    def _drop(self, h):
        """Apply feature dropout; active only while the module is in training mode."""
        return F.dropout(h, p=self.dropout, training=self.training)

    def forward(self, x, edge_index, edge_weight=None):
        """Return ``(mu, logstd)``; ``edge_weight`` (optional) is passed to every GCN layer."""
        graph = (edge_index, edge_weight)
        h = self.norm(x)
        # BatchNorm is applied to the post-ReLU activations of the first layer.
        h = self._drop(self.bn1(F.relu(self.gc1(h, *graph))))
        h = self._drop(F.relu(self.gc2(h, *graph)))
        return self.gc_mu(h, *graph), self.gc_log(h, *graph)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backward compatibility alias
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Encoder = GCNEncoder
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Registry mapping the lower-case encoder name to its class.
_ENCODERS = {
    "gcn": GCNEncoder,
    "gat": GATEncoder,
    "deep_gcn": DeepGCNEncoder,
}


def build_encoder(name: str, in_dim: int, hidden: int, latent: int, **kwargs) -> nn.Module:
    """Create the encoder registered under ``name``.

    Parameters
    ----------
    name : {"gcn", "gat", "deep_gcn"}
        Encoder identifier; matched case-insensitively.
    in_dim, hidden, latent : layer dimensions
    **kwargs : forwarded to the encoder constructor (e.g. dropout=0.3, heads=8)

    Returns
    -------
    nn.Module
        A freshly constructed encoder instance.

    Raises
    ------
    ValueError
        If ``name`` does not match any registered encoder.
    """
    name = name.lower()
    encoder_cls = _ENCODERS.get(name)
    if encoder_cls is None:
        raise ValueError(f"Unknown encoder '{name}'. Choose from {list(_ENCODERS)}")
    return encoder_cls(in_dim=in_dim, hidden=hidden, latent=latent, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user