def _load_metrics(model_path: str) -> dict:
    """Load the ``metrics.json`` stored in the same directory as ``model.pt``.

    Parameters
    ----------
    model_path : str
        Path to the saved model file. Only its directory is used; the file
        itself is not opened here.

    Returns
    -------
    dict
        The parsed JSON object, or an empty dict when no ``metrics.json``
        exists next to the model.

    Raises
    ------
    ValueError
        If ``metrics.json`` exists but its top-level value is not a JSON
        object. Malformed JSON raises ``json.JSONDecodeError`` (itself a
        ``ValueError`` subclass).
    """
    model_dir = os.path.dirname(os.path.abspath(model_path))
    metrics_path = os.path.join(model_dir, "metrics.json")

    # EAFP: open directly instead of exists()+open(), which is two lookups and
    # racy. The encoding is pinned so parsing does not depend on the locale.
    try:
        with open(metrics_path, encoding="utf-8") as f:
            metrics = json.load(f)
    except FileNotFoundError:
        return {}

    # Callers use dict methods on the result; fail here with a clear message
    # rather than later with an AttributeError on a list/None/number.
    if not isinstance(metrics, dict):
        raise ValueError(
            f"{metrics_path} must contain a JSON object, "
            f"got {type(metrics).__name__}"
        )
    return metrics
Keys: {keys}") - gc1_shape = _first_weight("gc1") # shape [hidden, in_dim] - gc_mu_shape = _first_weight("gc_mu") # shape [latent, hidden] + gc1_shape = _first_weight("gc1") + gc_mu_shape = _first_weight("gc_mu") hidden = gc1_shape[0] latent = gc_mu_shape[0] - - # in_dim from BatchNorm weight (shape [in_dim]) for k in keys: if "norm" in k and k.endswith("weight") and "running" not in k: in_dim = state_dict[k].shape[0] break else: - in_dim = gc1_shape[1] # fallback: second dim of gc1 weight - + in_dim = gc1_shape[1] return in_dim, hidden, latent @@ -71,16 +79,37 @@ def main(): data = torch.load(args.graph, weights_only=False) state_dict = torch.load(args.model, map_location="cpu", weights_only=False) - in_dim, hidden, latent = _infer_dims(state_dict) - print(f"Inferred: in_dim={in_dim} hidden={hidden} latent={latent}") + # Build edge_index and edge_weight (undirected, consistent with training) + ei, _ = remove_self_loops(data.edge_index) + if hasattr(data, "edge_weight") and data.edge_weight is not None: + ei, ew = to_undirected(ei, data.edge_weight, num_nodes=data.num_nodes) + else: + ei = to_undirected(ei, num_nodes=data.num_nodes) + ew = None - enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent) + # Read hyperparameters from metrics.json (preferred) or infer from state_dict + metrics = _load_metrics(args.model) + encoder_name = metrics.get("encoder", "gcn") + hidden = metrics.get("hidden") + latent = metrics.get("latent") + heads = metrics.get("heads") or 4 + + if hidden is None or latent is None: + print("metrics.json not found or incomplete — inferring dims from state_dict (GCN only)") + _, hidden, latent = _infer_dims_gcn(state_dict) + encoder_name = "gcn" + + in_dim = data.x.shape[1] + print(f"Encoder: {encoder_name} in_dim={in_dim} hidden={hidden} latent={latent}") + + enc = build_encoder(encoder_name, in_dim=in_dim, hidden=hidden, + latent=latent, dropout=0.0, heads=heads) model = VGAE(enc) model.load_state_dict(state_dict) model.eval() with 
torch.no_grad(): - z = model.encode(data.x.float(), data.edge_index) + z = model.encode(data.x.float(), ei, ew) os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True) np.save(args.out, z.cpu().numpy()) diff --git a/scripts/model.py b/scripts/model.py index 0a85b46..d8f47a1 100644 --- a/scripts/model.py +++ b/scripts/model.py @@ -1,23 +1,34 @@ #!/usr/bin/env python3 -"""Shared VGAE encoder. Imported by train_vgae.py and encode_graph.py.""" +""" +VGAE encoder architectures for chromatin contact graphs. + +Exported symbols +---------------- +GCNEncoder — original 2-layer GCN (kept for backward compatibility) +GATEncoder — 2-layer GATv2 with multi-head attention +DeepGCNEncoder — 3-layer GCN with residual BatchNorm between layers +Encoder — alias for GCNEncoder (backward compat) +build_encoder() — factory: returns the right class from a string name +""" import torch import torch.nn as nn import torch.nn.functional as F -from torch_geometric.nn import GCNConv +from torch_geometric.nn import GCNConv, GATv2Conv -class Encoder(nn.Module): - """Two-layer GCN encoder for VGAE with input BatchNorm. +# --------------------------------------------------------------------------- +# GCN encoder (baseline) +# --------------------------------------------------------------------------- - Architecture: BatchNorm → GCN(hidden) → ReLU → Dropout → GCN_mu / GCN_logstd +class GCNEncoder(nn.Module): + """Two-layer GCN encoder with input BatchNorm. - The BatchNorm layer normalises raw ChIP-seq signals and its running statistics - are saved in model.pt, so encode_graph.py applies identical normalisation to - held-out cell lines without a separate scaler file. 
+ Architecture: BatchNorm → GCNConv(hidden) → ReLU → Dropout + → GCNConv_mu / GCNConv_logstd """ - def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2): + def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2, **_): super().__init__() self.norm = nn.BatchNorm1d(in_dim) self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True) @@ -25,8 +36,116 @@ class Encoder(nn.Module): self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True) self.dropout = dropout - def forward(self, x, edge_index): + def forward(self, x, edge_index, edge_weight=None): x = self.norm(x) - h = F.relu(self.gc1(x, edge_index)) + h = F.relu(self.gc1(x, edge_index, edge_weight)) h = F.dropout(h, p=self.dropout, training=self.training) - return self.gc_mu(h, edge_index), self.gc_log(h, edge_index) + return self.gc_mu(h, edge_index, edge_weight), self.gc_log(h, edge_index, edge_weight) + + +# --------------------------------------------------------------------------- +# GAT encoder (preferred for Hi-C: handles degree heterogeneity via attention) +# --------------------------------------------------------------------------- + +class GATEncoder(nn.Module): + """Two-layer GATv2 encoder. + + Each GATv2 layer applies multi-head attention, which lets the model + up-weight high-frequency contacts at TAD boundaries and CTCF anchors + rather than averaging all neighbours uniformly (as GCN does). 
+ + Architecture: BatchNorm → GATv2(hidden, heads) → ELU → BN → Dropout + → GATv2(hidden, heads) → Dropout + → GCNConv_mu / GCNConv_logstd + """ + + def __init__(self, in_dim: int, hidden: int, latent: int, + heads: int = 4, dropout: float = 0.2, **_): + super().__init__() + if hidden % heads != 0: + raise ValueError(f"hidden ({hidden}) must be divisible by heads ({heads})") + self.norm = nn.BatchNorm1d(in_dim) + self.gat1 = GATv2Conv(in_dim, hidden // heads, heads=heads, + dropout=dropout, add_self_loops=True, concat=True) + self.bn1 = nn.BatchNorm1d(hidden) + self.gat2 = GATv2Conv(hidden, hidden // heads, heads=heads, + dropout=dropout, add_self_loops=True, concat=True) + self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True) + self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True) + self.dropout = dropout + + def forward(self, x, edge_index, edge_weight=None): + x = self.norm(x) + h = F.elu(self.gat1(x, edge_index)) + h = self.bn1(h) + h = F.dropout(h, p=self.dropout, training=self.training) + h = F.elu(self.gat2(h, edge_index)) + h = F.dropout(h, p=self.dropout, training=self.training) + # GATv2 learns its own attention weights; edge_weight is used only in the + # final linear projection layers (mu/log) where GCNConv accepts it. + return self.gc_mu(h, edge_index, edge_weight), self.gc_log(h, edge_index, edge_weight) + + +# --------------------------------------------------------------------------- +# Deep GCN encoder (3 message-passing layers) +# --------------------------------------------------------------------------- + +class DeepGCNEncoder(nn.Module): + """Three-layer GCN encoder — wider receptive field than the baseline. 
+ + Architecture: BatchNorm → GCN1 → BN → ReLU → Dropout + → GCN2 → ReLU → Dropout + → GCNConv_mu / GCNConv_logstd + """ + + def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2, **_): + super().__init__() + self.norm = nn.BatchNorm1d(in_dim) + self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True) + self.bn1 = nn.BatchNorm1d(hidden) + self.gc2 = GCNConv(hidden, hidden, add_self_loops=True, normalize=True) + self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True) + self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True) + self.dropout = dropout + + def forward(self, x, edge_index, edge_weight=None): + x = self.norm(x) + h = F.relu(self.gc1(x, edge_index, edge_weight)) + h = self.bn1(h) + h = F.dropout(h, p=self.dropout, training=self.training) + h = F.relu(self.gc2(h, edge_index, edge_weight)) + h = F.dropout(h, p=self.dropout, training=self.training) + return self.gc_mu(h, edge_index, edge_weight), self.gc_log(h, edge_index, edge_weight) + + +# --------------------------------------------------------------------------- +# Backward compatibility alias +# --------------------------------------------------------------------------- + +Encoder = GCNEncoder + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + +_ENCODERS = { + "gcn": GCNEncoder, + "gat": GATEncoder, + "deep_gcn": DeepGCNEncoder, +} + + +def build_encoder(name: str, in_dim: int, hidden: int, latent: int, **kwargs) -> nn.Module: + """Instantiate an encoder by name. + + Parameters + ---------- + name : {"gcn", "gat", "deep_gcn"} + in_dim, hidden, latent : layer dimensions + **kwargs : passed to the constructor (e.g. dropout=0.3, heads=8) + """ + name = name.lower() + if name not in _ENCODERS: + raise ValueError(f"Unknown encoder '{name}'. 
Choose from {list(_ENCODERS)}") + return _ENCODERS[name](in_dim=in_dim, hidden=hidden, latent=latent, **kwargs) diff --git a/scripts/train_vgae.py b/scripts/train_vgae.py index 65f6a35..66083ef 100644 --- a/scripts/train_vgae.py +++ b/scripts/train_vgae.py @@ -2,15 +2,25 @@ """ Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph. +Key improvements over v1 +------------------------ +• GAT / DeepGCN encoders selectable via --encoder +• β-VGAE with linear KL warm-up (--kl_anneal): lets the encoder learn + structure before the prior regularises the latent space +• Lower default LR (3e-4) and higher patience so the optimiser doesn't + overshoot the minimum +• --beta controls the final KL weight (default 1.0; try 0.5 to prevent + posterior collapse on sparse graphs) + Inputs ------ - PyTorch Geometric Data object saved by build_graph.py. + PyTorch Geometric Data object from build_graph.py Outputs (under --outdir) ------------------------ - model.pt trained VGAE state_dict (includes BatchNorm running statistics) + model.pt trained state_dict (+ encoder type stored in metrics.json) emb.npy node embeddings — mu vector, shape [num_nodes, latent_dim] - metrics.json val/test AUC & AP plus all hyperparameters + metrics.json val/test AUC & AP, all hyperparameters """ import argparse @@ -30,19 +40,14 @@ from torch_geometric.utils import ( from sklearn.metrics import average_precision_score, roc_auc_score sys.path.insert(0, os.path.dirname(__file__)) -from model import Encoder +from model import build_encoder @torch.no_grad() def _eval_linkpred(z, pos_edges, neg_edges): - """Return (AUROC, AP) for link prediction.""" - def _sigmoid(x): - return 1.0 / (1.0 + torch.exp(-x)) - def _score(edges): src, dst = edges - return _sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy() - + return torch.sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy() y_true = np.concatenate([np.ones(pos_edges.size(1)), np.zeros(neg_edges.size(1))]) y_pred = 
np.concatenate([_score(pos_edges), _score(neg_edges)]) @@ -53,15 +58,29 @@ def main(): ap = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter ) + # Data ap.add_argument("--graph", required=True, help="Path to Data .pt file from build_graph.py") - ap.add_argument("--epochs", type=int, default=300) - ap.add_argument("--patience", type=int, default=20, - help="Early-stopping patience (val-AUC epochs without improvement)") - ap.add_argument("--lr", type=float, default=1e-3) - ap.add_argument("--hidden", type=int, default=64) - ap.add_argument("--latent", type=int, default=32) - ap.add_argument("--dropout", type=float, default=0.2) + # Architecture + ap.add_argument("--encoder", default="gat", + choices=["gcn", "gat", "deep_gcn"], + help="Encoder architecture (default: gat)") + ap.add_argument("--hidden", type=int, default=128) + ap.add_argument("--latent", type=int, default=64) + ap.add_argument("--heads", type=int, default=4, + help="Number of attention heads (GAT only)") + ap.add_argument("--dropout", type=float, default=0.3) + # Training + ap.add_argument("--epochs", type=int, default=500) + ap.add_argument("--patience", type=int, default=50, + help="Early-stopping patience (val-AUC epochs)") + ap.add_argument("--lr", type=float, default=3e-4) + ap.add_argument("--beta", type=float, default=1.0, + help="Final KL weight in the ELBO (β-VGAE). 
" + "Values < 1 reduce regularisation on sparse graphs.") + ap.add_argument("--kl_anneal",type=int, default=100, + help="Linearly warm up KL weight from 0 → beta over " + "this many epochs (0 = no annealing).") ap.add_argument("--seed", type=int, default=42) ap.add_argument("--outdir", default="results") args = ap.parse_args() @@ -73,13 +92,21 @@ def main(): # ---- Load and clean graph ---- data = torch.load(args.graph, weights_only=False) ei, _ = remove_self_loops(data.edge_index) - data.edge_index = to_undirected(ei, num_nodes=data.num_nodes) + # Propagate edge_weight through to_undirected so split weights stay aligned + if hasattr(data, "edge_weight") and data.edge_weight is not None: + ei, ew = to_undirected(ei, data.edge_weight, num_nodes=data.num_nodes) + data.edge_weight = ew + else: + ei = to_undirected(ei, num_nodes=data.num_nodes) + data.edge_index = ei x = data.x.float() print(f"Graph: {data.num_nodes} nodes " f"{data.edge_index.shape[1]} edges " f"{x.shape[1]} node features") + print(f"Encoder: {args.encoder} hidden={args.hidden} latent={args.latent}" + + (f" heads={args.heads}" if args.encoder == "gat" else "")) - # ---- Edge splits for link-prediction evaluation ---- + # ---- Edge splits ---- splitter = RandomLinkSplit( num_val=0.1, num_test=0.1, is_undirected=True, @@ -88,6 +115,8 @@ def main(): ) train_data, val_data, test_data = splitter(data) train_data.pos_edge_index = train_data.edge_index + train_ew = getattr(train_data, "edge_weight", None) + full_ew = getattr(data, "edge_weight", None) for split in (val_data, test_data): split.pos_edge_index = split.edge_index @@ -99,33 +128,47 @@ def main(): ) # ---- Model ---- - enc = Encoder(in_dim=x.size(1), hidden=args.hidden, - latent=args.latent, dropout=args.dropout) + enc = build_encoder(args.encoder, in_dim=x.size(1), + hidden=args.hidden, latent=args.latent, + dropout=args.dropout, heads=args.heads) model = VGAE(enc) - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + optimizer = 
torch.optim.Adam(model.parameters(), lr=args.lr, + weight_decay=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="max", factor=0.5, patience=15, verbose=False + ) - # ---- Training loop with early stopping ---- + # ---- Training loop with β-VGAE KL warm-up ---- best_val_auc = -1.0 best_state = None no_improve = 0 epochs_ran = 0 for epoch in range(1, args.epochs + 1): + # Linear KL warm-up: β rises from 0 to args.beta over kl_anneal epochs + if args.kl_anneal > 0: + kl_w = min(args.beta, args.beta * epoch / args.kl_anneal) + else: + kl_w = args.beta + model.train() optimizer.zero_grad() - z = model.encode(x, train_data.edge_index) + z = model.encode(x, train_data.edge_index, train_ew) loss = (model.recon_loss(z, train_data.pos_edge_index) - + (1.0 / data.num_nodes) * model.kl_loss()) + + (kl_w / data.num_nodes) * model.kl_loss()) loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() model.eval() with torch.no_grad(): - z_full = model.encode(x, data.edge_index) + z_full = model.encode(x, data.edge_index, full_ew) val_auc, val_ap = _eval_linkpred( z_full, val_data.pos_edge_index, val_data.neg_edge_index ) + scheduler.step(val_auc) + if val_auc > best_val_auc: best_val_auc = val_auc best_state = {k: v.cpu().clone() @@ -135,41 +178,43 @@ def main(): no_improve += 1 epochs_ran = epoch - if epoch % 10 == 0 or epoch == 1: + if epoch % 20 == 0 or epoch == 1: + lr_now = optimizer.param_groups[0]["lr"] print(f"[{epoch:03d}/{args.epochs}] " - f"loss={loss.item():.4f} " - f"val AUC={val_auc:.4f} AP={val_ap:.4f}") + f"loss={loss.item():.4f} kl_w={kl_w:.3f} " + f"val AUC={val_auc:.4f} AP={val_ap:.4f} lr={lr_now:.2e}") if no_improve >= args.patience: print(f"Early stopping at epoch {epoch} " f"(no val-AUC improvement for {args.patience} epochs)") break - # ---- Restore best checkpoint and compute test metrics ---- + # ---- Restore best and test ---- model.load_state_dict(best_state) model.eval() with 
torch.no_grad(): - z_final = model.encode(x, data.edge_index) + z_final = model.encode(x, data.edge_index, full_ew) test_auc, test_ap = _eval_linkpred( z_final, test_data.pos_edge_index, test_data.neg_edge_index ) # ---- Save outputs ---- - model_path = os.path.join(args.outdir, "model.pt") - torch.save(best_state, model_path) - - emb_path = os.path.join(args.outdir, "emb.npy") - np.save(emb_path, z_final.cpu().numpy()) + torch.save(best_state, os.path.join(args.outdir, "model.pt")) + np.save(os.path.join(args.outdir, "emb.npy"), z_final.cpu().numpy()) metrics = { + "encoder": args.encoder, "val_auc": float(best_val_auc), "test_auc": float(test_auc), "test_ap": float(test_ap), "epochs_ran": epochs_ran, "epochs_max": args.epochs, "patience": args.patience, + "beta": args.beta, + "kl_anneal": args.kl_anneal, "hidden": args.hidden, "latent": args.latent, + "heads": args.heads if args.encoder == "gat" else None, "dropout": args.dropout, "lr": args.lr, "seed": args.seed, @@ -177,9 +222,10 @@ def main(): with open(os.path.join(args.outdir, "metrics.json"), "w") as f: json.dump(metrics, f, indent=2) - print(f"\nSaved model → {model_path}") - print(f"Saved embeddings → {emb_path} shape={z_final.shape}") - print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f}") + print(f"\nSaved → {args.outdir}/") + print(f"Embeddings shape: {z_final.shape}") + print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f} " + f"(val best={best_val_auc:.4f})") if __name__ == "__main__":