#!/usr/bin/env python3
"""
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.

Key improvements over v1
------------------------
• GAT / DeepGCN encoders selectable via --encoder
• β-VGAE with linear KL warm-up (--kl_anneal): lets the encoder learn
  structure before the prior regularises the latent space
• Lower default LR (3e-4) and higher patience so the optimiser doesn't
  overshoot the minimum
• --beta controls the final KL weight (default 1.0; try 0.5 to prevent
  posterior collapse on sparse graphs)

Evaluation protocol
-------------------
Validation / test metrics are computed on the held-out supervision edges
produced by RandomLinkSplit (edge_label_index / edge_label), using
embeddings encoded from that split's own message-passing graph, so held-out
edges never take part in message passing when they are being scored.

Inputs
------
  PyTorch Geometric Data object from build_graph.py

Outputs (under --outdir)
------------------------
  model.pt      trained state_dict (+ encoder type stored in metrics.json)
  emb.npy       node embeddings — mu vector, shape [num_nodes, latent_dim]
  metrics.json  val/test AUC & AP, all hyperparameters
"""

import argparse
import json
import os
import sys

import numpy as np
import torch
from torch_geometric.nn.models import VGAE
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import (
    negative_sampling,
    remove_self_loops,
    to_undirected,
)
from sklearn.metrics import average_precision_score, roc_auc_score

# Make the sibling model.py importable regardless of the working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model import build_encoder


@torch.no_grad()
def _eval_linkpred(z, pos_edges, neg_edges):
    """Return (ROC-AUC, average precision) for a link-prediction split.

    Each candidate edge is scored as sigmoid(<z[src], z[dst]>). Columns of
    ``pos_edges`` are labelled 1 and columns of ``neg_edges`` are labelled 0.

    Args:
        z: Node embedding matrix indexed by node id along dim 0.
        pos_edges: ``[2, P]`` index tensor of positive (true) edges.
        neg_edges: ``[2, N]`` index tensor of negative (absent) edges.
    """
    def _score(edges):
        src, dst = edges
        return torch.sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy()

    y_true = np.concatenate([np.ones(pos_edges.size(1)),
                             np.zeros(neg_edges.size(1))])
    y_pred = np.concatenate([_score(pos_edges), _score(neg_edges)])
    return roc_auc_score(y_true, y_pred), average_precision_score(y_true, y_pred)


def _supervision_edges(split):
    """Split a RandomLinkSplit output into (positive, negative) edge indices.

    Reads ``split.edge_label_index`` and ``split.edge_label``; columns with a
    label greater than zero are positives, columns labelled zero are the
    negatives sampled by the splitter.
    """
    labels = split.edge_label
    candidates = split.edge_label_index
    return candidates[:, labels > 0], candidates[:, labels == 0]


def main():
    """Parse CLI options, train the VGAE with early stopping, save outputs."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    # Data
    ap.add_argument("--graph", required=True,
                    help="Path to Data .pt file from build_graph.py")
    # Architecture
    ap.add_argument("--encoder", default="gat",
                    choices=["gcn", "gat", "deep_gcn"],
                    help="Encoder architecture (default: gat)")
    ap.add_argument("--hidden", type=int, default=128)
    ap.add_argument("--latent", type=int, default=64)
    ap.add_argument("--heads", type=int, default=4,
                    help="Number of attention heads (GAT only)")
    ap.add_argument("--dropout", type=float, default=0.3)
    # Training
    ap.add_argument("--epochs", type=int, default=500)
    ap.add_argument("--patience", type=int, default=50,
                    help="Early-stopping patience (val-AUC epochs)")
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--beta", type=float, default=1.0,
                    help="Final KL weight in the ELBO (β-VGAE). "
                         "Values < 1 reduce regularisation on sparse graphs.")
    ap.add_argument("--kl_anneal", type=int, default=100,
                    help="Linearly warm up KL weight from 0 → beta over "
                         "this many epochs (0 = no annealing).")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--outdir", default="results")
    args = ap.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    os.makedirs(args.outdir, exist_ok=True)

    # ---- Load and clean graph ----
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load graph files that you (or build_graph.py) produced yourself.
    data = torch.load(args.graph, weights_only=False)
    ei, _ = remove_self_loops(data.edge_index)

    # Propagate edge_weight through to_undirected so split weights stay aligned
    if hasattr(data, "edge_weight") and data.edge_weight is not None:
        ei, ew = to_undirected(ei, data.edge_weight, num_nodes=data.num_nodes)
        data.edge_weight = ew
    else:
        ei = to_undirected(ei, num_nodes=data.num_nodes)
    data.edge_index = ei

    x = data.x.float()
    print(f"Graph: {data.num_nodes} nodes "
          f"{data.edge_index.shape[1]} edges "
          f"{x.shape[1]} node features")
    print(f"Encoder: {args.encoder} hidden={args.hidden} latent={args.latent}"
          + (f" heads={args.heads}" if args.encoder == "gat" else ""))

    # ---- Edge splits ----
    splitter = RandomLinkSplit(
        num_val=0.1,
        num_test=0.1,
        is_undirected=True,
        add_negative_train_samples=False,
        split_labels=False,
    )
    train_data, val_data, test_data = splitter(data)

    # Training reconstructs the edges that are also used for message passing.
    train_data.pos_edge_index = train_data.edge_index
    train_ew = getattr(train_data, "edge_weight", None)
    full_ew = getattr(data, "edge_weight", None)

    # For val/test, `edge_index` is only the message-passing graph (it holds
    # no held-out edges of that split). The edges to be scored, together with
    # the splitter's negatives, live in edge_label_index / edge_label.
    for split in (val_data, test_data):
        split.pos_edge_index, split.neg_edge_index = _supervision_edges(split)
    val_ew = getattr(val_data, "edge_weight", None)
    test_ew = getattr(test_data, "edge_weight", None)

    # ---- Model ----
    enc = build_encoder(args.encoder, in_dim=x.size(1),
                        hidden=args.hidden, latent=args.latent,
                        dropout=args.dropout, heads=args.heads)
    model = VGAE(enc)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=1e-5)
    # `verbose` is intentionally not passed: it is deprecated in recent
    # PyTorch releases and False was the historical default anyway.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=15
    )

    # ---- Training loop with β-VGAE KL warm-up ----
    best_val_auc = -1.0
    best_val_ap = -1.0
    best_state = None
    no_improve = 0
    epochs_ran = 0

    for epoch in range(1, args.epochs + 1):
        # Linear KL warm-up: β rises from 0 to args.beta over kl_anneal epochs
        if args.kl_anneal > 0:
            kl_w = min(args.beta, args.beta * epoch / args.kl_anneal)
        else:
            kl_w = args.beta

        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_data.edge_index, train_ew)
        # kl_loss() reads the mu/logstd cached by the encode() call above, so
        # it must be evaluated before the validation encode() below.
        loss = (model.recon_loss(z, train_data.pos_edge_index)
                + (kl_w / data.num_nodes) * model.kl_loss())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        model.eval()
        with torch.no_grad():
            # Encode with the validation message-passing graph only, so the
            # held-out edges being scored cannot leak into the embeddings.
            z_val = model.encode(x, val_data.edge_index, val_ew)
        val_auc, val_ap = _eval_linkpred(
            z_val, val_data.pos_edge_index, val_data.neg_edge_index
        )
        scheduler.step(val_auc)

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_val_ap = val_ap
            best_state = {k: v.cpu().clone()
                          for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1

        epochs_ran = epoch
        if epoch % 20 == 0 or epoch == 1:
            lr_now = optimizer.param_groups[0]["lr"]
            print(f"[{epoch:03d}/{args.epochs}] "
                  f"loss={loss.item():.4f} kl_w={kl_w:.3f} "
                  f"val AUC={val_auc:.4f} AP={val_ap:.4f} lr={lr_now:.2e}")

        if no_improve >= args.patience:
            print(f"Early stopping at epoch {epoch} "
                  f"(no val-AUC improvement for {args.patience} epochs)")
            break

    # ---- Restore best and test ----
    if best_state is None:
        # No epoch yielded a comparable val AUC (e.g. --epochs 0 or NaN
        # scores); fall back to the current weights instead of crashing.
        best_state = {k: v.cpu().clone()
                      for k, v in model.state_dict().items()}
    else:
        model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        # Test metrics: message passing over the test split's graph only.
        z_test = model.encode(x, test_data.edge_index, test_ew)
        # Saved embeddings: use every known edge of the cleaned graph.
        z_final = model.encode(x, data.edge_index, full_ew)
    test_auc, test_ap = _eval_linkpred(
        z_test, test_data.pos_edge_index, test_data.neg_edge_index
    )

    # ---- Save outputs ----
    torch.save(best_state, os.path.join(args.outdir, "model.pt"))
    np.save(os.path.join(args.outdir, "emb.npy"), z_final.cpu().numpy())

    metrics = {
        "encoder": args.encoder,
        "val_auc": float(best_val_auc),
        "val_ap": float(best_val_ap),
        "test_auc": float(test_auc),
        "test_ap": float(test_ap),
        "epochs_ran": epochs_ran,
        "epochs_max": args.epochs,
        "patience": args.patience,
        "beta": args.beta,
        "kl_anneal": args.kl_anneal,
        "hidden": args.hidden,
        "latent": args.latent,
        "heads": args.heads if args.encoder == "gat" else None,
        "dropout": args.dropout,
        "lr": args.lr,
        "seed": args.seed,
    }
    with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)

    print(f"\nSaved → {args.outdir}/")
    print(f"Embeddings shape: {z_final.shape}")
    print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f} "
          f"(val best={best_val_auc:.4f})")


if __name__ == "__main__":
    main()