fix: three silent bugs in training and inference

- Edge weights (ICE log1p) were computed but never passed to GCNConv
- encode_graph.py hardcoded GCNEncoder regardless of saved model type
- Inference graph lacked to_undirected + edge_weight pipeline from training
This commit is contained in:
2026-05-15 03:37:20 +02:00
parent acadbd780c
commit a081f29e12
3 changed files with 263 additions and 69 deletions

View File

@@ -2,15 +2,25 @@
"""
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
Key improvements over v1
------------------------
• GAT / DeepGCN encoders selectable via --encoder
• β-VGAE with linear KL warm-up (--kl_anneal): lets the encoder learn
structure before the prior regularises the latent space
• Lower default LR (3e-4) so the optimiser doesn't overshoot the minimum,
and higher early-stopping patience (50) so slow improvement isn't cut short
• --beta controls the final KL weight (default 1.0; try 0.5 to prevent
posterior collapse on sparse graphs)
Inputs
------
PyTorch Geometric Data object saved by build_graph.py.
PyTorch Geometric Data object from build_graph.py
Outputs (under --outdir)
------------------------
model.pt trained VGAE state_dict (includes BatchNorm running statistics)
model.pt trained state_dict only (the encoder type is recorded in metrics.json)
emb.npy node embeddings — mu vector, shape [num_nodes, latent_dim]
metrics.json val/test AUC & AP plus all hyperparameters
metrics.json val/test AUC & AP, all hyperparameters
"""
import argparse
@@ -30,19 +40,14 @@ from torch_geometric.utils import (
from sklearn.metrics import average_precision_score, roc_auc_score
sys.path.insert(0, os.path.dirname(__file__))
from model import Encoder
from model import build_encoder
@torch.no_grad()
def _eval_linkpred(z, pos_edges, neg_edges):
"""Return (AUROC, AP) for link prediction."""
def _sigmoid(x):
return 1.0 / (1.0 + torch.exp(-x))
def _score(edges):
src, dst = edges
return _sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy()
return torch.sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy()
y_true = np.concatenate([np.ones(pos_edges.size(1)),
np.zeros(neg_edges.size(1))])
y_pred = np.concatenate([_score(pos_edges), _score(neg_edges)])
@@ -53,15 +58,29 @@ def main():
ap = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
# Data
ap.add_argument("--graph", required=True,
help="Path to Data .pt file from build_graph.py")
ap.add_argument("--epochs", type=int, default=300)
ap.add_argument("--patience", type=int, default=20,
help="Early-stopping patience (val-AUC epochs without improvement)")
ap.add_argument("--lr", type=float, default=1e-3)
ap.add_argument("--hidden", type=int, default=64)
ap.add_argument("--latent", type=int, default=32)
ap.add_argument("--dropout", type=float, default=0.2)
# Architecture
ap.add_argument("--encoder", default="gat",
choices=["gcn", "gat", "deep_gcn"],
help="Encoder architecture (default: gat)")
ap.add_argument("--hidden", type=int, default=128)
ap.add_argument("--latent", type=int, default=64)
ap.add_argument("--heads", type=int, default=4,
help="Number of attention heads (GAT only)")
ap.add_argument("--dropout", type=float, default=0.3)
# Training
ap.add_argument("--epochs", type=int, default=500)
ap.add_argument("--patience", type=int, default=50,
help="Early-stopping patience (val-AUC epochs)")
ap.add_argument("--lr", type=float, default=3e-4)
ap.add_argument("--beta", type=float, default=1.0,
help="Final KL weight in the ELBO (β-VGAE). "
"Values < 1 reduce regularisation on sparse graphs.")
ap.add_argument("--kl_anneal",type=int, default=100,
help="Linearly warm up KL weight from 0 → beta over "
"this many epochs (0 = no annealing).")
ap.add_argument("--seed", type=int, default=42)
ap.add_argument("--outdir", default="results")
args = ap.parse_args()
@@ -73,13 +92,21 @@ def main():
# ---- Load and clean graph ----
data = torch.load(args.graph, weights_only=False)
ei, _ = remove_self_loops(data.edge_index)
data.edge_index = to_undirected(ei, num_nodes=data.num_nodes)
# Propagate edge_weight through to_undirected so split weights stay aligned
if hasattr(data, "edge_weight") and data.edge_weight is not None:
ei, ew = to_undirected(ei, data.edge_weight, num_nodes=data.num_nodes)
data.edge_weight = ew
else:
ei = to_undirected(ei, num_nodes=data.num_nodes)
data.edge_index = ei
x = data.x.float()
print(f"Graph: {data.num_nodes} nodes "
f"{data.edge_index.shape[1]} edges "
f"{x.shape[1]} node features")
print(f"Encoder: {args.encoder} hidden={args.hidden} latent={args.latent}"
+ (f" heads={args.heads}" if args.encoder == "gat" else ""))
# ---- Edge splits for link-prediction evaluation ----
# ---- Edge splits ----
splitter = RandomLinkSplit(
num_val=0.1, num_test=0.1,
is_undirected=True,
@@ -88,6 +115,8 @@ def main():
)
train_data, val_data, test_data = splitter(data)
train_data.pos_edge_index = train_data.edge_index
train_ew = getattr(train_data, "edge_weight", None)
full_ew = getattr(data, "edge_weight", None)
for split in (val_data, test_data):
split.pos_edge_index = split.edge_index
@@ -99,33 +128,47 @@ def main():
)
# ---- Model ----
enc = Encoder(in_dim=x.size(1), hidden=args.hidden,
latent=args.latent, dropout=args.dropout)
enc = build_encoder(args.encoder, in_dim=x.size(1),
hidden=args.hidden, latent=args.latent,
dropout=args.dropout, heads=args.heads)
model = VGAE(enc)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.5, patience=15, verbose=False
)
# ---- Training loop with early stopping ----
# ---- Training loop with β-VGAE KL warm-up ----
best_val_auc = -1.0
best_state = None
no_improve = 0
epochs_ran = 0
for epoch in range(1, args.epochs + 1):
# Linear KL warm-up: the KL weight is args.beta * epoch / kl_anneal, i.e.
# beta/kl_anneal at epoch 1, capped at args.beta from epoch kl_anneal onward
if args.kl_anneal > 0:
kl_w = min(args.beta, args.beta * epoch / args.kl_anneal)
else:
kl_w = args.beta
model.train()
optimizer.zero_grad()
z = model.encode(x, train_data.edge_index)
z = model.encode(x, train_data.edge_index, train_ew)
loss = (model.recon_loss(z, train_data.pos_edge_index)
+ (1.0 / data.num_nodes) * model.kl_loss())
+ (kl_w / data.num_nodes) * model.kl_loss())
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
model.eval()
with torch.no_grad():
z_full = model.encode(x, data.edge_index)
z_full = model.encode(x, data.edge_index, full_ew)
val_auc, val_ap = _eval_linkpred(
z_full, val_data.pos_edge_index, val_data.neg_edge_index
)
scheduler.step(val_auc)
if val_auc > best_val_auc:
best_val_auc = val_auc
best_state = {k: v.cpu().clone()
@@ -135,41 +178,43 @@ def main():
no_improve += 1
epochs_ran = epoch
if epoch % 10 == 0 or epoch == 1:
if epoch % 20 == 0 or epoch == 1:
lr_now = optimizer.param_groups[0]["lr"]
print(f"[{epoch:03d}/{args.epochs}] "
f"loss={loss.item():.4f} "
f"val AUC={val_auc:.4f} AP={val_ap:.4f}")
f"loss={loss.item():.4f} kl_w={kl_w:.3f} "
f"val AUC={val_auc:.4f} AP={val_ap:.4f} lr={lr_now:.2e}")
if no_improve >= args.patience:
print(f"Early stopping at epoch {epoch} "
f"(no val-AUC improvement for {args.patience} epochs)")
break
# ---- Restore best checkpoint and compute test metrics ----
# ---- Restore best and test ----
model.load_state_dict(best_state)
model.eval()
with torch.no_grad():
z_final = model.encode(x, data.edge_index)
z_final = model.encode(x, data.edge_index, full_ew)
test_auc, test_ap = _eval_linkpred(
z_final, test_data.pos_edge_index, test_data.neg_edge_index
)
# ---- Save outputs ----
model_path = os.path.join(args.outdir, "model.pt")
torch.save(best_state, model_path)
emb_path = os.path.join(args.outdir, "emb.npy")
np.save(emb_path, z_final.cpu().numpy())
torch.save(best_state, os.path.join(args.outdir, "model.pt"))
np.save(os.path.join(args.outdir, "emb.npy"), z_final.cpu().numpy())
metrics = {
"encoder": args.encoder,
"val_auc": float(best_val_auc),
"test_auc": float(test_auc),
"test_ap": float(test_ap),
"epochs_ran": epochs_ran,
"epochs_max": args.epochs,
"patience": args.patience,
"beta": args.beta,
"kl_anneal": args.kl_anneal,
"hidden": args.hidden,
"latent": args.latent,
"heads": args.heads if args.encoder == "gat" else None,
"dropout": args.dropout,
"lr": args.lr,
"seed": args.seed,
@@ -177,9 +222,10 @@ def main():
with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
json.dump(metrics, f, indent=2)
print(f"\nSaved model → {model_path}")
print(f"Saved embeddings {emb_path} shape={z_final.shape}")
print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f}")
print(f"\nSaved {args.outdir}/")
print(f"Embeddings shape: {z_final.shape}")
print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f} "
f"(val best={best_val_auc:.4f})")
if __name__ == "__main__":