261 lines
9.9 KiB
Python
261 lines
9.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
|
|
|
|
Key improvements over v1
|
|
------------------------
|
|
• GAT / DeepGCN encoders selectable via --encoder
|
|
• β-VGAE with linear KL warm-up (--kl_anneal): lets the encoder learn
|
|
structure before the prior regularises the latent space
|
|
• Lower default LR (3e-4) and higher patience so the optimiser doesn't
|
|
overshoot the minimum
|
|
• --beta controls the final KL weight (default 1.0; try 0.5 to prevent
|
|
posterior collapse on sparse graphs)
|
|
|
|
Inputs
|
|
------
|
|
PyTorch Geometric Data object from build_graph.py
|
|
|
|
Outputs (under --outdir)
|
|
------------------------
|
|
model.pt trained state_dict (+ encoder type stored in metrics.json)
|
|
emb.npy node embeddings — mu vector, shape [num_nodes, latent_dim]
|
|
metrics.json val/test AUC & AP, all hyperparameters
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch_geometric.nn.models import VGAE
|
|
from torch_geometric.transforms import RandomLinkSplit
|
|
from torch_geometric.utils import (
|
|
negative_sampling,
|
|
remove_self_loops,
|
|
to_undirected,
|
|
)
|
|
from sklearn.metrics import average_precision_score, roc_auc_score
|
|
|
|
from chromatin_gnn.model import build_encoder
|
|
|
|
|
|
@torch.no_grad()
def _eval_linkpred(z, pos_edges, neg_edges):
    """Compute link-prediction ROC-AUC and average precision.

    Each candidate edge is scored with the sigmoid of the inner product of
    its two endpoint embeddings.

    Args:
        z: Node embedding matrix, indexed by node id along the first axis.
        pos_edges: Two-row index tensor of edges labelled 1.
        neg_edges: Two-row index tensor of edges labelled 0.

    Returns:
        Tuple ``(roc_auc, average_precision)`` as computed by scikit-learn.
    """

    def edge_probs(edge_index):
        # Row 0 holds one endpoint of every edge, row 1 the other.
        heads, tails = edge_index
        logits = (z[heads] * z[tails]).sum(dim=1)
        return torch.sigmoid(logits).cpu().numpy()

    # Labels and scores are laid out positives first, then negatives.
    labels = np.concatenate([
        np.ones(pos_edges.size(1)),
        np.zeros(neg_edges.size(1)),
    ])
    pos_scores = edge_probs(pos_edges)
    neg_scores = edge_probs(neg_edges)
    scores = np.concatenate([pos_scores, neg_scores])

    auc = roc_auc_score(labels, scores)
    avg_precision = average_precision_score(labels, scores)
    return auc, avg_precision
|
|
|
|
|
|
def main():
    """Command-line entry point: train a VGAE and write model, embeddings, metrics.

    Steps:
      1. Parse CLI arguments, resolve the device and seed the RNGs.
      2. Load the graph, strip self-loops and symmetrise it (edge weights are
         carried through ``to_undirected`` when present).
      3. Split the edges 80/10/10 with ``RandomLinkSplit``.
      4. Train with a linearly warmed-up KL weight, gradient clipping,
         ``ReduceLROnPlateau`` on validation AUC, and early stopping.
      5. Restore the best checkpoint, score the test split, and write
         ``model.pt``, ``emb.npy`` and ``metrics.json`` under ``--outdir``.

    Evaluation protocol:
      ``RandomLinkSplit`` stores each split's *message-passing* edges in
      ``edge_index`` and its *held-out* supervision edges in the
      ``*_edge_label_index`` attributes.  Validation and test metrics are
      therefore computed on ``pos_edge_label_index`` / ``neg_edge_label_index``
      with embeddings encoded from that split's own ``edge_index``, so the
      edges being predicted are never visible to the encoder.  The exported
      embeddings (``emb.npy``) are encoded from the full graph, since they
      are a downstream artefact rather than an evaluation quantity.
    """
    ap = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    # Data
    ap.add_argument("--graph", required=True,
                    help="Path to Data .pt file from build_graph.py")
    # Architecture
    ap.add_argument("--encoder", default="gat",
                    choices=["gcn", "gat", "deep_gcn"],
                    help="Encoder architecture (default: gat)")
    ap.add_argument("--hidden", type=int, default=128)
    ap.add_argument("--latent", type=int, default=64)
    ap.add_argument("--heads", type=int, default=4,
                    help="Number of attention heads (GAT only)")
    ap.add_argument("--dropout", type=float, default=0.3)
    # Training
    ap.add_argument("--epochs", type=int, default=500)
    ap.add_argument("--patience", type=int, default=50,
                    help="Early-stopping patience (val-AUC epochs)")
    ap.add_argument("--lr", type=float, default=3e-4)
    ap.add_argument("--beta", type=float, default=1.0,
                    help="Final KL weight in the ELBO (β-VGAE). "
                         "Values < 1 reduce regularisation on sparse graphs.")
    ap.add_argument("--kl_anneal", type=int, default=100,
                    help="Linearly warm up KL weight from 0 → beta over "
                         "this many epochs (0 = no annealing).")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--outdir", default="results")
    ap.add_argument("--device", default="auto",
                    help="cuda | cpu | auto (default: auto-detect CUDA)")
    ap.add_argument("--constant_features", action="store_true",
                    help="Replace node features with a constant ones vector "
                         "(topology-only ablation: removes feature signal so the "
                         "encoder must rely on graph structure alone).")
    args = ap.parse_args()

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    # Query the name of the device actually selected (e.g. cuda:1), not GPU 0.
    print(f"Using device: {device}"
          + (f" ({torch.cuda.get_device_name(device)})"
             if device.type == "cuda" else ""))

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)
    os.makedirs(args.outdir, exist_ok=True)

    # ---- Load and clean graph ----
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which is
    # needed for a PyG Data object but can execute code; only load graph files
    # from a trusted source.
    # map_location="cpu" lets a graph saved on GPU load on a CPU-only host;
    # the tensors that are needed are moved to `device` explicitly below.
    data = torch.load(args.graph, map_location="cpu", weights_only=False)
    ei, _ = remove_self_loops(data.edge_index)
    # Propagate edge_weight through to_undirected so split weights stay aligned
    if hasattr(data, "edge_weight") and data.edge_weight is not None:
        ei, ew = to_undirected(ei, data.edge_weight, num_nodes=data.num_nodes)
        data.edge_weight = ew
    else:
        ei = to_undirected(ei, num_nodes=data.num_nodes)
    data.edge_index = ei

    # Resolve node features.  The ablation branch is checked first so that a
    # graph stored without features can still be trained in topology-only mode.
    if args.constant_features:
        x = torch.ones(data.num_nodes, 1)
        print("[ablation] Replacing node features with constant ones "
              "(topology-only mode).")
    elif data.x is None:
        ap.error("graph has no node features (data.x is missing); "
                 "re-run with --constant_features")
    else:
        x = data.x.float()
    print(f"Graph: {data.num_nodes} nodes "
          f"{data.edge_index.shape[1]} edges "
          f"{x.shape[1]} node features")
    print(f"Encoder: {args.encoder} hidden={args.hidden} latent={args.latent}"
          + (f" heads={args.heads}" if args.encoder == "gat" else ""))

    # ---- Edge splits ----
    # split_labels=True exposes the held-out supervision edges separately as
    # pos_edge_label_index / neg_edge_label_index.  Validation and test
    # negatives are sampled by the transform itself.
    splitter = RandomLinkSplit(
        num_val=0.1, num_test=0.1,
        is_undirected=True,
        add_negative_train_samples=False,
        split_labels=True,
    )
    train_data, val_data, test_data = splitter(data)

    def _on_device(tensor):
        # Edge weights are optional; pass None through untouched.
        return None if tensor is None else tensor.to(device)

    # ---- Move tensors to device ----
    x = x.to(device)
    # Training: message passing and reconstruction targets are the same
    # (bidirectional) training edges; negatives are drawn inside recon_loss.
    train_edge_index = train_data.edge_index.to(device)
    train_pos_edge = train_edge_index
    train_ew = _on_device(getattr(train_data, "edge_weight", None))
    # Validation: encode on val_data.edge_index, score the held-out val edges.
    val_edge_index = val_data.edge_index.to(device)
    val_ew = _on_device(getattr(val_data, "edge_weight", None))
    val_pos = val_data.pos_edge_label_index.to(device)
    val_neg = val_data.neg_edge_label_index.to(device)
    # Test: encode on test_data.edge_index, score the held-out test edges.
    test_edge_index = test_data.edge_index.to(device)
    test_ew = _on_device(getattr(test_data, "edge_weight", None))
    test_pos = test_data.pos_edge_label_index.to(device)
    test_neg = test_data.neg_edge_label_index.to(device)
    # Full graph: used only to export the final embeddings.
    full_edge_index = data.edge_index.to(device)
    full_ew = _on_device(getattr(data, "edge_weight", None))

    # ---- Model ----
    enc = build_encoder(args.encoder, in_dim=x.size(1),
                        hidden=args.hidden, latent=args.latent,
                        dropout=args.dropout, heads=args.heads)
    model = VGAE(enc).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=1e-5)
    # `verbose` is deprecated in recent PyTorch and False is already the
    # default, so it is not passed.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=15
    )

    # ---- Training loop with β-VGAE KL warm-up ----
    best_val_auc = -1.0
    best_val_ap = None
    best_state = None
    no_improve = 0
    epochs_ran = 0

    for epoch in range(1, args.epochs + 1):
        # Linear KL warm-up: β rises from 0 to args.beta over kl_anneal epochs
        if args.kl_anneal > 0:
            kl_w = min(args.beta, args.beta * epoch / args.kl_anneal)
        else:
            kl_w = args.beta

        model.train()
        optimizer.zero_grad()
        z = model.encode(x, train_edge_index, train_ew)
        loss = (model.recon_loss(z, train_pos_edge)
                + (kl_w / data.num_nodes) * model.kl_loss())
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        model.eval()
        with torch.no_grad():
            # Encode without the validation edges so they are truly unseen.
            z_val = model.encode(x, val_edge_index, val_ew)
        val_auc, val_ap = _eval_linkpred(z_val, val_pos, val_neg)

        scheduler.step(val_auc)

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_val_ap = val_ap
            # Snapshot on CPU so the checkpoint does not hold device memory.
            best_state = {k: v.cpu().clone()
                          for k, v in model.state_dict().items()}
            no_improve = 0
        else:
            no_improve += 1

        epochs_ran = epoch
        if epoch % 20 == 0 or epoch == 1:
            lr_now = optimizer.param_groups[0]["lr"]
            print(f"[{epoch:03d}/{args.epochs}] "
                  f"loss={loss.item():.4f} kl_w={kl_w:.3f} "
                  f"val AUC={val_auc:.4f} AP={val_ap:.4f} lr={lr_now:.2e}")

        if no_improve >= args.patience:
            print(f"Early stopping at epoch {epoch} "
                  f"(no val-AUC improvement for {args.patience} epochs)")
            break

    # ---- Restore best and test ----
    if best_state is None:
        # No epoch ran (e.g. --epochs 0): fall back to the current weights
        # instead of crashing on a missing checkpoint.
        best_state = {k: v.cpu().clone()
                      for k, v in model.state_dict().items()}
    model.load_state_dict({k: v.to(device) for k, v in best_state.items()})
    model.eval()
    with torch.no_grad():
        # Test metrics: the encoder sees only test_data.edge_index.
        z_test = model.encode(x, test_edge_index, test_ew)
        # Exported embeddings: encoded from the full graph.
        z_final = model.encode(x, full_edge_index, full_ew)
    test_auc, test_ap = _eval_linkpred(z_test, test_pos, test_neg)

    # ---- Save outputs ----
    torch.save(best_state, os.path.join(args.outdir, "model.pt"))
    np.save(os.path.join(args.outdir, "emb.npy"), z_final.cpu().numpy())

    metrics = {
        "encoder": args.encoder,
        "val_auc": float(best_val_auc),
        # AP at the best-AUC epoch; null when no epoch was run.
        "val_ap": None if best_val_ap is None else float(best_val_ap),
        "test_auc": float(test_auc),
        "test_ap": float(test_ap),
        "epochs_ran": epochs_ran,
        "epochs_max": args.epochs,
        "patience": args.patience,
        "beta": args.beta,
        "kl_anneal": args.kl_anneal,
        "hidden": args.hidden,
        "latent": args.latent,
        "heads": args.heads if args.encoder == "gat" else None,
        "dropout": args.dropout,
        "lr": args.lr,
        "seed": args.seed,
        "constant_features": bool(args.constant_features),
        "in_features": int(x.size(1)),
    }
    with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)

    print(f"\nSaved → {args.outdir}/")
    print(f"Embeddings shape: {z_final.shape}")
    print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f} "
          f"(val best={best_val_auc:.4f})")
|
|
|
|
|
# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|