v1.0.0: VGAE applied to GM12878 vs IMR90 chr21 Hi-C at 25kb
Full reproducible pipeline: .mcool + ChIP-seq bigwigs → latent embeddings → A/B compartment calls → cross-cell comparison. Key results (chr21, 25 kb, latent dim=32): - Test AUC=0.777, AP=0.759 (converged epoch 31/300) - GM12878 A/B silhouette (cosine) = 0.775 - IMR90 zero-shot silhouette = 0.443 - A-compartment bins stable across cell types (mean cosine Δ=0.042) - B-compartment bins shift substantially (mean cosine Δ=0.451) - 101 B→A and 70 A→B compartment switches GM12878→IMR90
This commit is contained in:
@@ -1,179 +1,185 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
|
||||
---
|
||||
Inputs:
|
||||
- A PyTorch Geometric Data object saved with torch.save()
|
||||
- from build_graph.py
|
||||
---
|
||||
Outputs (under results/):
|
||||
- model.pt : trained VGAE state_dict
|
||||
- emb.npy : node embeddings (mean; shape [num_nodes, latent_dim])
|
||||
- metrics.json : train/val/test AUC/AP summary
|
||||
|
||||
Inputs
|
||||
------
|
||||
PyTorch Geometric Data object saved by build_graph.py.
|
||||
|
||||
Outputs (under --outdir)
|
||||
------------------------
|
||||
model.pt trained VGAE state_dict (includes BatchNorm running statistics)
|
||||
emb.npy node embeddings — mu vector, shape [num_nodes, latent_dim]
|
||||
metrics.json val/test AUC & AP plus all hyperparameters
|
||||
"""
|
||||
|
||||
import os, json, argparse, numpy as np, torch
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_geometric.nn.models import VGAE
|
||||
from torch_geometric.transforms import RandomLinkSplit
|
||||
from torch_geometric.utils import to_undirected, remove_self_loops
|
||||
from torch_geometric.utils import negative_sampling
|
||||
from sklearn.metrics import roc_auc_score, average_precision_score
|
||||
from torch_geometric.utils import (
|
||||
negative_sampling,
|
||||
remove_self_loops,
|
||||
to_undirected,
|
||||
)
|
||||
from sklearn.metrics import average_precision_score, roc_auc_score
|
||||
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2):
|
||||
super().__init__()
|
||||
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
|
||||
self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
h = self.gc1(x, edge_index)
|
||||
h = F.relu(h)
|
||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from model import Encoder
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_linkpred(model, data_like, z):
|
||||
"""Compute AUROC/AP using provided positive/negative edges."""
|
||||
pos = data_like.pos_edge_index
|
||||
neg = data_like.neg_edge_index
|
||||
# model.test returns (auc, ap) but relies on torchmetrics in some versions;
|
||||
# compute explicitly for stability:
|
||||
def sigmoid(x): return 1 / (1 + torch.exp(-x))
|
||||
def _eval_linkpred(z, pos_edges, neg_edges):
|
||||
"""Return (AUROC, AP) for link prediction."""
|
||||
def _sigmoid(x):
|
||||
return 1.0 / (1.0 + torch.exp(-x))
|
||||
|
||||
# Inner product decoder scores
|
||||
def scores(edges):
|
||||
def _score(edges):
|
||||
src, dst = edges
|
||||
s = (z[src] * z[dst]).sum(dim=1)
|
||||
return sigmoid(s).cpu().numpy()
|
||||
return _sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy()
|
||||
|
||||
y_true = np.concatenate([np.ones(pos.size(1)), np.zeros(neg.size(1))])
|
||||
y_pred = np.concatenate([scores(pos), scores(neg)])
|
||||
|
||||
auc = roc_auc_score(y_true, y_pred)
|
||||
ap = average_precision_score(y_true, y_pred)
|
||||
return auc, ap
|
||||
y_true = np.concatenate([np.ones(pos_edges.size(1)),
|
||||
np.zeros(neg_edges.size(1))])
|
||||
y_pred = np.concatenate([_score(pos_edges), _score(neg_edges)])
|
||||
return roc_auc_score(y_true, y_pred), average_precision_score(y_true, y_pred)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--graph", required=True, help="Path to Data .pt file")
|
||||
ap.add_argument("--epochs", type=int, default=100)
|
||||
ap.add_argument("--lr", type=float, default=1e-3)
|
||||
ap.add_argument("--hidden", type=int, default=128)
|
||||
ap.add_argument("--latent", type=int, default=64)
|
||||
ap.add_argument("--dropout", type=float, default=0.2)
|
||||
ap.add_argument("--seed", type=int, default=42)
|
||||
ap.add_argument("--outdir", default="results")
|
||||
ap = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
ap.add_argument("--graph", required=True,
|
||||
help="Path to Data .pt file from build_graph.py")
|
||||
ap.add_argument("--epochs", type=int, default=300)
|
||||
ap.add_argument("--patience", type=int, default=20,
|
||||
help="Early-stopping patience (val-AUC epochs without improvement)")
|
||||
ap.add_argument("--lr", type=float, default=1e-3)
|
||||
ap.add_argument("--hidden", type=int, default=64)
|
||||
ap.add_argument("--latent", type=int, default=32)
|
||||
ap.add_argument("--dropout", type=float, default=0.2)
|
||||
ap.add_argument("--seed", type=int, default=42)
|
||||
ap.add_argument("--outdir", default="results")
|
||||
args = ap.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
os.makedirs(args.outdir, exist_ok=True)
|
||||
|
||||
# Load graph
|
||||
data = torch.load(args.graph)
|
||||
# Coalesce/clean edges
|
||||
# ---- Load and clean graph ----
|
||||
data = torch.load(args.graph, weights_only=False)
|
||||
ei, _ = remove_self_loops(data.edge_index)
|
||||
data.edge_index = to_undirected(ei, num_nodes=data.num_nodes)
|
||||
x = data.x.float()
|
||||
print(f"Graph: {data.num_nodes} nodes "
|
||||
f"{data.edge_index.shape[1]} edges "
|
||||
f"{x.shape[1]} node features")
|
||||
|
||||
# Split edges for link prediction
|
||||
# ---- Edge splits for link-prediction evaluation ----
|
||||
splitter = RandomLinkSplit(
|
||||
num_val=0.1,
|
||||
num_test=0.1,
|
||||
num_val=0.1, num_test=0.1,
|
||||
is_undirected=True,
|
||||
add_negative_train_samples=False,
|
||||
split_labels=False,
|
||||
)
|
||||
train_data, val_data, test_data = splitter(data)
|
||||
|
||||
# Positive edges are just the edges in each split
|
||||
train_data.pos_edge_index = train_data.edge_index
|
||||
val_data.pos_edge_index = val_data.edge_index
|
||||
test_data.pos_edge_index = test_data.edge_index
|
||||
|
||||
# Generate negative edges for validation and test manually
|
||||
for subset in [val_data, test_data]:
|
||||
subset.neg_edge_index = negative_sampling(
|
||||
edge_index=subset.edge_index,
|
||||
for split in (val_data, test_data):
|
||||
split.pos_edge_index = split.edge_index
|
||||
split.neg_edge_index = negative_sampling(
|
||||
edge_index=split.edge_index,
|
||||
num_nodes=data.num_nodes,
|
||||
num_neg_samples=subset.edge_index.size(1),
|
||||
method='sparse'
|
||||
num_neg_samples=split.edge_index.size(1),
|
||||
method="sparse",
|
||||
)
|
||||
|
||||
|
||||
# Model
|
||||
enc = Encoder(in_dim=x.size(1), hidden=args.hidden, latent=args.latent, dropout=args.dropout)
|
||||
# ---- Model ----
|
||||
enc = Encoder(in_dim=x.size(1), hidden=args.hidden,
|
||||
latent=args.latent, dropout=args.dropout)
|
||||
model = VGAE(enc)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
# Training loop
|
||||
# ---- Training loop with early stopping ----
|
||||
best_val_auc = -1.0
|
||||
best_state = None
|
||||
best_state = None
|
||||
no_improve = 0
|
||||
epochs_ran = 0
|
||||
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
# Encode using remaining training edges
|
||||
z = model.encode(x, train_data.edge_index)
|
||||
# Reconstruction loss on positive training edges (negatives sampled inside)
|
||||
loss_recon = model.recon_loss(z, train_data.pos_edge_index)
|
||||
# KL divergence regularizer
|
||||
loss_kl = (1.0 / data.num_nodes) * model.kl_loss()
|
||||
loss = loss_recon + loss_kl
|
||||
z = model.encode(x, train_data.edge_index)
|
||||
loss = (model.recon_loss(z, train_data.pos_edge_index)
|
||||
+ (1.0 / data.num_nodes) * model.kl_loss())
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
z_full = model.encode(x, data.edge_index) # use full graph for eval embeddings
|
||||
val_auc, val_ap = eval_linkpred(model, val_data, z_full)
|
||||
z_full = model.encode(x, data.edge_index)
|
||||
val_auc, val_ap = _eval_linkpred(
|
||||
z_full, val_data.pos_edge_index, val_data.neg_edge_index
|
||||
)
|
||||
|
||||
if val_auc > best_val_auc:
|
||||
best_val_auc = val_auc
|
||||
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
||||
best_state = {k: v.cpu().clone()
|
||||
for k, v in model.state_dict().items()}
|
||||
no_improve = 0
|
||||
else:
|
||||
no_improve += 1
|
||||
|
||||
epochs_ran = epoch
|
||||
if epoch % 10 == 0 or epoch == 1:
|
||||
print(f"[{epoch:03d}/{args.epochs}] loss={loss.item():.4f} | val AUC={val_auc:.4f} AP={val_ap:.4f}")
|
||||
print(f"[{epoch:03d}/{args.epochs}] "
|
||||
f"loss={loss.item():.4f} "
|
||||
f"val AUC={val_auc:.4f} AP={val_ap:.4f}")
|
||||
|
||||
# Save best model
|
||||
if no_improve >= args.patience:
|
||||
print(f"Early stopping at epoch {epoch} "
|
||||
f"(no val-AUC improvement for {args.patience} epochs)")
|
||||
break
|
||||
|
||||
# ---- Restore best checkpoint and compute test metrics ----
|
||||
model.load_state_dict(best_state)
|
||||
model_path = os.path.join(args.outdir, "model.pt")
|
||||
torch.save(model.state_dict(), model_path)
|
||||
|
||||
# Final test metrics
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
z_final = model.encode(x, data.edge_index)
|
||||
test_auc, test_ap = eval_linkpred(model, test_data, z_final)
|
||||
test_auc, test_ap = _eval_linkpred(
|
||||
z_final, test_data.pos_edge_index, test_data.neg_edge_index
|
||||
)
|
||||
|
||||
# ---- Save outputs ----
|
||||
model_path = os.path.join(args.outdir, "model.pt")
|
||||
torch.save(best_state, model_path)
|
||||
|
||||
# Save embeddings & metrics
|
||||
emb_path = os.path.join(args.outdir, "emb.npy")
|
||||
np.save(emb_path, z_final.cpu().numpy())
|
||||
|
||||
metrics = {
|
||||
"val_auc": float(best_val_auc),
|
||||
"test_auc": float(test_auc),
|
||||
"test_ap": float(test_ap),
|
||||
"epochs": args.epochs,
|
||||
"hidden": args.hidden,
|
||||
"latent": args.latent,
|
||||
"dropout": args.dropout,
|
||||
"lr": args.lr,
|
||||
"seed": args.seed
|
||||
"val_auc": float(best_val_auc),
|
||||
"test_auc": float(test_auc),
|
||||
"test_ap": float(test_ap),
|
||||
"epochs_ran": epochs_ran,
|
||||
"epochs_max": args.epochs,
|
||||
"patience": args.patience,
|
||||
"hidden": args.hidden,
|
||||
"latent": args.latent,
|
||||
"dropout": args.dropout,
|
||||
"lr": args.lr,
|
||||
"seed": args.seed,
|
||||
}
|
||||
with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
|
||||
json.dump(metrics, f, indent=2)
|
||||
|
||||
print(f"Saved model -> {model_path}")
|
||||
print(f"Saved embeddings -> {emb_path} (shape={z_final.shape})")
|
||||
print(f"Metrics: AUC(test)={test_auc:.4f}, AP(test)={test_ap:.4f}")
|
||||
print(f"\nSaved model → {model_path}")
|
||||
print(f"Saved embeddings → {emb_path} shape={z_final.shape}")
|
||||
print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user