fix: three silent bugs in training and inference

- Edge weights (ICE log1p) were computed but never passed to GCNConv
- encode_graph.py hardcoded GCNEncoder regardless of saved model type
- Inference graph lacked to_undirected + edge_weight pipeline from training
This commit is contained in:
2026-05-15 03:37:20 +02:00
parent acadbd780c
commit a081f29e12
3 changed files with 263 additions and 69 deletions

View File

@@ -2,9 +2,10 @@
""" """
Encode a chromatin contact graph using a trained VGAE model. Encode a chromatin contact graph using a trained VGAE model.
Dimensions (in_dim, hidden, latent) are inferred automatically from the saved Dimensions (in_dim, hidden, latent) and encoder type are read from the
state_dict. The BatchNorm running statistics from training are restored, so the metrics.json saved alongside model.pt. Edge weights are passed to GCN-based
same normalisation is applied to held-out cell lines without a separate scaler. encoders so the same weighted message-passing used during training is applied
at inference time.
Usage Usage
----- -----
@@ -15,19 +16,30 @@ Usage
""" """
import argparse import argparse
import json
import os import os
import sys import sys
import numpy as np import numpy as np
import torch import torch
from torch_geometric.nn.models import VGAE from torch_geometric.nn.models import VGAE
from torch_geometric.utils import remove_self_loops, to_undirected
sys.path.insert(0, os.path.dirname(__file__)) sys.path.insert(0, os.path.dirname(__file__))
from model import Encoder from model import build_encoder, GCNEncoder
def _infer_dims(state_dict: dict) -> tuple: def _load_metrics(model_path: str) -> dict:
"""Infer (in_dim, hidden, latent) from a VGAE state_dict.""" """Read metrics.json from the same directory as model.pt."""
metrics_path = os.path.join(os.path.dirname(os.path.abspath(model_path)), "metrics.json")
if os.path.exists(metrics_path):
with open(metrics_path) as f:
return json.load(f)
return {}
def _infer_dims_gcn(state_dict: dict) -> tuple:
"""Fallback: infer (in_dim, hidden, latent) from a GCN state_dict."""
keys = list(state_dict.keys()) keys = list(state_dict.keys())
def _first_weight(substr): def _first_weight(substr):
@@ -37,22 +49,18 @@ def _infer_dims(state_dict: dict) -> tuple:
and "running" not in k and "running" not in k
and "num_batches" not in k): and "num_batches" not in k):
return state_dict[k].shape return state_dict[k].shape
raise KeyError(f"No weight key containing '{substr}' in state_dict. " raise KeyError(f"No weight key containing '{substr}'. Keys: {keys}")
f"Available keys: {keys}")
gc1_shape = _first_weight("gc1") # shape [hidden, in_dim] gc1_shape = _first_weight("gc1")
gc_mu_shape = _first_weight("gc_mu") # shape [latent, hidden] gc_mu_shape = _first_weight("gc_mu")
hidden = gc1_shape[0] hidden = gc1_shape[0]
latent = gc_mu_shape[0] latent = gc_mu_shape[0]
# in_dim from BatchNorm weight (shape [in_dim])
for k in keys: for k in keys:
if "norm" in k and k.endswith("weight") and "running" not in k: if "norm" in k and k.endswith("weight") and "running" not in k:
in_dim = state_dict[k].shape[0] in_dim = state_dict[k].shape[0]
break break
else: else:
in_dim = gc1_shape[1] # fallback: second dim of gc1 weight in_dim = gc1_shape[1]
return in_dim, hidden, latent return in_dim, hidden, latent
@@ -71,16 +79,37 @@ def main():
data = torch.load(args.graph, weights_only=False) data = torch.load(args.graph, weights_only=False)
state_dict = torch.load(args.model, map_location="cpu", weights_only=False) state_dict = torch.load(args.model, map_location="cpu", weights_only=False)
in_dim, hidden, latent = _infer_dims(state_dict) # Build edge_index and edge_weight (undirected, consistent with training)
print(f"Inferred: in_dim={in_dim} hidden={hidden} latent={latent}") ei, _ = remove_self_loops(data.edge_index)
if hasattr(data, "edge_weight") and data.edge_weight is not None:
ei, ew = to_undirected(ei, data.edge_weight, num_nodes=data.num_nodes)
else:
ei = to_undirected(ei, num_nodes=data.num_nodes)
ew = None
enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent) # Read hyperparameters from metrics.json (preferred) or infer from state_dict
metrics = _load_metrics(args.model)
encoder_name = metrics.get("encoder", "gcn")
hidden = metrics.get("hidden")
latent = metrics.get("latent")
heads = metrics.get("heads") or 4
if hidden is None or latent is None:
print("metrics.json not found or incomplete — inferring dims from state_dict (GCN only)")
_, hidden, latent = _infer_dims_gcn(state_dict)
encoder_name = "gcn"
in_dim = data.x.shape[1]
print(f"Encoder: {encoder_name} in_dim={in_dim} hidden={hidden} latent={latent}")
enc = build_encoder(encoder_name, in_dim=in_dim, hidden=hidden,
latent=latent, dropout=0.0, heads=heads)
model = VGAE(enc) model = VGAE(enc)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
z = model.encode(data.x.float(), data.edge_index) z = model.encode(data.x.float(), ei, ew)
os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True) os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
np.save(args.out, z.cpu().numpy()) np.save(args.out, z.cpu().numpy())

View File

@@ -1,23 +1,34 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Shared VGAE encoder. Imported by train_vgae.py and encode_graph.py.""" """
VGAE encoder architectures for chromatin contact graphs.
Exported symbols
----------------
GCNEncoder — original 2-layer GCN (kept for backward compatibility)
GATEncoder — 2-layer GATv2 with multi-head attention
DeepGCNEncoder — 3-layer GCN with residual BatchNorm between layers
Encoder — alias for GCNEncoder (backward compat)
build_encoder() — factory: returns the right class from a string name
"""
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch_geometric.nn import GCNConv from torch_geometric.nn import GCNConv, GATv2Conv
class Encoder(nn.Module): # ---------------------------------------------------------------------------
"""Two-layer GCN encoder for VGAE with input BatchNorm. # GCN encoder (baseline)
# ---------------------------------------------------------------------------
Architecture: BatchNorm → GCN(hidden) → ReLU → Dropout → GCN_mu / GCN_logstd class GCNEncoder(nn.Module):
"""Two-layer GCN encoder with input BatchNorm.
The BatchNorm layer normalises raw ChIP-seq signals and its running statistics Architecture: BatchNorm → GCNConv(hidden) → ReLU → Dropout
are saved in model.pt, so encode_graph.py applies identical normalisation to → GCNConv_mu / GCNConv_logstd
held-out cell lines without a separate scaler file.
""" """
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2): def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2, **_):
super().__init__() super().__init__()
self.norm = nn.BatchNorm1d(in_dim) self.norm = nn.BatchNorm1d(in_dim)
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True) self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
@@ -25,8 +36,116 @@ class Encoder(nn.Module):
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True) self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
self.dropout = dropout self.dropout = dropout
def forward(self, x, edge_index): def forward(self, x, edge_index, edge_weight=None):
x = self.norm(x) x = self.norm(x)
h = F.relu(self.gc1(x, edge_index)) h = F.relu(self.gc1(x, edge_index, edge_weight))
h = F.dropout(h, p=self.dropout, training=self.training) h = F.dropout(h, p=self.dropout, training=self.training)
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index) return self.gc_mu(h, edge_index, edge_weight), self.gc_log(h, edge_index, edge_weight)
# ---------------------------------------------------------------------------
# GAT encoder (preferred for Hi-C: handles degree heterogeneity via attention)
# ---------------------------------------------------------------------------
class GATEncoder(nn.Module):
"""Two-layer GATv2 encoder.
Each GATv2 layer applies multi-head attention, which lets the model
up-weight high-frequency contacts at TAD boundaries and CTCF anchors
rather than averaging all neighbours uniformly (as GCN does).
Architecture: BatchNorm → GATv2(hidden, heads) → ELU → BN → Dropout
→ GATv2(hidden, heads) → Dropout
→ GCNConv_mu / GCNConv_logstd
"""
def __init__(self, in_dim: int, hidden: int, latent: int,
heads: int = 4, dropout: float = 0.2, **_):
super().__init__()
if hidden % heads != 0:
raise ValueError(f"hidden ({hidden}) must be divisible by heads ({heads})")
self.norm = nn.BatchNorm1d(in_dim)
self.gat1 = GATv2Conv(in_dim, hidden // heads, heads=heads,
dropout=dropout, add_self_loops=True, concat=True)
self.bn1 = nn.BatchNorm1d(hidden)
self.gat2 = GATv2Conv(hidden, hidden // heads, heads=heads,
dropout=dropout, add_self_loops=True, concat=True)
self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
self.dropout = dropout
def forward(self, x, edge_index, edge_weight=None):
x = self.norm(x)
h = F.elu(self.gat1(x, edge_index))
h = self.bn1(h)
h = F.dropout(h, p=self.dropout, training=self.training)
h = F.elu(self.gat2(h, edge_index))
h = F.dropout(h, p=self.dropout, training=self.training)
# GATv2 learns its own attention weights; edge_weight is used only in the
# final linear projection layers (mu/log) where GCNConv accepts it.
return self.gc_mu(h, edge_index, edge_weight), self.gc_log(h, edge_index, edge_weight)
# ---------------------------------------------------------------------------
# Deep GCN encoder (3 message-passing layers)
# ---------------------------------------------------------------------------
class DeepGCNEncoder(nn.Module):
"""Three-layer GCN encoder — wider receptive field than the baseline.
Architecture: BatchNorm → GCN1 → BN → ReLU → Dropout
→ GCN2 → ReLU → Dropout
→ GCNConv_mu / GCNConv_logstd
"""
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2, **_):
super().__init__()
self.norm = nn.BatchNorm1d(in_dim)
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
self.bn1 = nn.BatchNorm1d(hidden)
self.gc2 = GCNConv(hidden, hidden, add_self_loops=True, normalize=True)
self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
self.dropout = dropout
def forward(self, x, edge_index, edge_weight=None):
x = self.norm(x)
h = F.relu(self.gc1(x, edge_index, edge_weight))
h = self.bn1(h)
h = F.dropout(h, p=self.dropout, training=self.training)
h = F.relu(self.gc2(h, edge_index, edge_weight))
h = F.dropout(h, p=self.dropout, training=self.training)
return self.gc_mu(h, edge_index, edge_weight), self.gc_log(h, edge_index, edge_weight)
# ---------------------------------------------------------------------------
# Backward compatibility alias
# ---------------------------------------------------------------------------
Encoder = GCNEncoder
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
_ENCODERS = {
"gcn": GCNEncoder,
"gat": GATEncoder,
"deep_gcn": DeepGCNEncoder,
}
def build_encoder(name: str, in_dim: int, hidden: int, latent: int, **kwargs) -> nn.Module:
"""Instantiate an encoder by name.
Parameters
----------
name : {"gcn", "gat", "deep_gcn"}
in_dim, hidden, latent : layer dimensions
**kwargs : passed to the constructor (e.g. dropout=0.3, heads=8)
"""
name = name.lower()
if name not in _ENCODERS:
raise ValueError(f"Unknown encoder '{name}'. Choose from {list(_ENCODERS)}")
return _ENCODERS[name](in_dim=in_dim, hidden=hidden, latent=latent, **kwargs)

View File

@@ -2,15 +2,25 @@
""" """
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph. Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
Key improvements over v1
------------------------
• GAT / DeepGCN encoders selectable via --encoder
• β-VGAE with linear KL warm-up (--kl_anneal): lets the encoder learn
structure before the prior regularises the latent space
• Lower default LR (3e-4) and higher patience so the optimiser doesn't
overshoot the minimum
• --beta controls the final KL weight (default 1.0; try 0.5 to prevent
posterior collapse on sparse graphs)
Inputs Inputs
------ ------
PyTorch Geometric Data object saved by build_graph.py. PyTorch Geometric Data object from build_graph.py
Outputs (under --outdir) Outputs (under --outdir)
------------------------ ------------------------
model.pt trained VGAE state_dict (includes BatchNorm running statistics) model.pt trained state_dict (+ encoder type stored in metrics.json)
emb.npy node embeddings — mu vector, shape [num_nodes, latent_dim] emb.npy node embeddings — mu vector, shape [num_nodes, latent_dim]
metrics.json val/test AUC & AP plus all hyperparameters metrics.json val/test AUC & AP, all hyperparameters
""" """
import argparse import argparse
@@ -30,19 +40,14 @@ from torch_geometric.utils import (
from sklearn.metrics import average_precision_score, roc_auc_score from sklearn.metrics import average_precision_score, roc_auc_score
sys.path.insert(0, os.path.dirname(__file__)) sys.path.insert(0, os.path.dirname(__file__))
from model import Encoder from model import build_encoder
@torch.no_grad() @torch.no_grad()
def _eval_linkpred(z, pos_edges, neg_edges): def _eval_linkpred(z, pos_edges, neg_edges):
"""Return (AUROC, AP) for link prediction."""
def _sigmoid(x):
return 1.0 / (1.0 + torch.exp(-x))
def _score(edges): def _score(edges):
src, dst = edges src, dst = edges
return _sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy() return torch.sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy()
y_true = np.concatenate([np.ones(pos_edges.size(1)), y_true = np.concatenate([np.ones(pos_edges.size(1)),
np.zeros(neg_edges.size(1))]) np.zeros(neg_edges.size(1))])
y_pred = np.concatenate([_score(pos_edges), _score(neg_edges)]) y_pred = np.concatenate([_score(pos_edges), _score(neg_edges)])
@@ -53,15 +58,29 @@ def main():
ap = argparse.ArgumentParser( ap = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
) )
# Data
ap.add_argument("--graph", required=True, ap.add_argument("--graph", required=True,
help="Path to Data .pt file from build_graph.py") help="Path to Data .pt file from build_graph.py")
ap.add_argument("--epochs", type=int, default=300) # Architecture
ap.add_argument("--patience", type=int, default=20, ap.add_argument("--encoder", default="gat",
help="Early-stopping patience (val-AUC epochs without improvement)") choices=["gcn", "gat", "deep_gcn"],
ap.add_argument("--lr", type=float, default=1e-3) help="Encoder architecture (default: gat)")
ap.add_argument("--hidden", type=int, default=64) ap.add_argument("--hidden", type=int, default=128)
ap.add_argument("--latent", type=int, default=32) ap.add_argument("--latent", type=int, default=64)
ap.add_argument("--dropout", type=float, default=0.2) ap.add_argument("--heads", type=int, default=4,
help="Number of attention heads (GAT only)")
ap.add_argument("--dropout", type=float, default=0.3)
# Training
ap.add_argument("--epochs", type=int, default=500)
ap.add_argument("--patience", type=int, default=50,
help="Early-stopping patience (val-AUC epochs)")
ap.add_argument("--lr", type=float, default=3e-4)
ap.add_argument("--beta", type=float, default=1.0,
help="Final KL weight in the ELBO (β-VGAE). "
"Values < 1 reduce regularisation on sparse graphs.")
ap.add_argument("--kl_anneal",type=int, default=100,
help="Linearly warm up KL weight from 0 → beta over "
"this many epochs (0 = no annealing).")
ap.add_argument("--seed", type=int, default=42) ap.add_argument("--seed", type=int, default=42)
ap.add_argument("--outdir", default="results") ap.add_argument("--outdir", default="results")
args = ap.parse_args() args = ap.parse_args()
@@ -73,13 +92,21 @@ def main():
# ---- Load and clean graph ---- # ---- Load and clean graph ----
data = torch.load(args.graph, weights_only=False) data = torch.load(args.graph, weights_only=False)
ei, _ = remove_self_loops(data.edge_index) ei, _ = remove_self_loops(data.edge_index)
data.edge_index = to_undirected(ei, num_nodes=data.num_nodes) # Propagate edge_weight through to_undirected so split weights stay aligned
if hasattr(data, "edge_weight") and data.edge_weight is not None:
ei, ew = to_undirected(ei, data.edge_weight, num_nodes=data.num_nodes)
data.edge_weight = ew
else:
ei = to_undirected(ei, num_nodes=data.num_nodes)
data.edge_index = ei
x = data.x.float() x = data.x.float()
print(f"Graph: {data.num_nodes} nodes " print(f"Graph: {data.num_nodes} nodes "
f"{data.edge_index.shape[1]} edges " f"{data.edge_index.shape[1]} edges "
f"{x.shape[1]} node features") f"{x.shape[1]} node features")
print(f"Encoder: {args.encoder} hidden={args.hidden} latent={args.latent}"
+ (f" heads={args.heads}" if args.encoder == "gat" else ""))
# ---- Edge splits for link-prediction evaluation ---- # ---- Edge splits ----
splitter = RandomLinkSplit( splitter = RandomLinkSplit(
num_val=0.1, num_test=0.1, num_val=0.1, num_test=0.1,
is_undirected=True, is_undirected=True,
@@ -88,6 +115,8 @@ def main():
) )
train_data, val_data, test_data = splitter(data) train_data, val_data, test_data = splitter(data)
train_data.pos_edge_index = train_data.edge_index train_data.pos_edge_index = train_data.edge_index
train_ew = getattr(train_data, "edge_weight", None)
full_ew = getattr(data, "edge_weight", None)
for split in (val_data, test_data): for split in (val_data, test_data):
split.pos_edge_index = split.edge_index split.pos_edge_index = split.edge_index
@@ -99,33 +128,47 @@ def main():
) )
# ---- Model ---- # ---- Model ----
enc = Encoder(in_dim=x.size(1), hidden=args.hidden, enc = build_encoder(args.encoder, in_dim=x.size(1),
latent=args.latent, dropout=args.dropout) hidden=args.hidden, latent=args.latent,
dropout=args.dropout, heads=args.heads)
model = VGAE(enc) model = VGAE(enc)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="max", factor=0.5, patience=15, verbose=False
)
# ---- Training loop with early stopping ---- # ---- Training loop with β-VGAE KL warm-up ----
best_val_auc = -1.0 best_val_auc = -1.0
best_state = None best_state = None
no_improve = 0 no_improve = 0
epochs_ran = 0 epochs_ran = 0
for epoch in range(1, args.epochs + 1): for epoch in range(1, args.epochs + 1):
# Linear KL warm-up: β rises from 0 to args.beta over kl_anneal epochs
if args.kl_anneal > 0:
kl_w = min(args.beta, args.beta * epoch / args.kl_anneal)
else:
kl_w = args.beta
model.train() model.train()
optimizer.zero_grad() optimizer.zero_grad()
z = model.encode(x, train_data.edge_index) z = model.encode(x, train_data.edge_index, train_ew)
loss = (model.recon_loss(z, train_data.pos_edge_index) loss = (model.recon_loss(z, train_data.pos_edge_index)
+ (1.0 / data.num_nodes) * model.kl_loss()) + (kl_w / data.num_nodes) * model.kl_loss())
loss.backward() loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step() optimizer.step()
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
z_full = model.encode(x, data.edge_index) z_full = model.encode(x, data.edge_index, full_ew)
val_auc, val_ap = _eval_linkpred( val_auc, val_ap = _eval_linkpred(
z_full, val_data.pos_edge_index, val_data.neg_edge_index z_full, val_data.pos_edge_index, val_data.neg_edge_index
) )
scheduler.step(val_auc)
if val_auc > best_val_auc: if val_auc > best_val_auc:
best_val_auc = val_auc best_val_auc = val_auc
best_state = {k: v.cpu().clone() best_state = {k: v.cpu().clone()
@@ -135,41 +178,43 @@ def main():
no_improve += 1 no_improve += 1
epochs_ran = epoch epochs_ran = epoch
if epoch % 10 == 0 or epoch == 1: if epoch % 20 == 0 or epoch == 1:
lr_now = optimizer.param_groups[0]["lr"]
print(f"[{epoch:03d}/{args.epochs}] " print(f"[{epoch:03d}/{args.epochs}] "
f"loss={loss.item():.4f} " f"loss={loss.item():.4f} kl_w={kl_w:.3f} "
f"val AUC={val_auc:.4f} AP={val_ap:.4f}") f"val AUC={val_auc:.4f} AP={val_ap:.4f} lr={lr_now:.2e}")
if no_improve >= args.patience: if no_improve >= args.patience:
print(f"Early stopping at epoch {epoch} " print(f"Early stopping at epoch {epoch} "
f"(no val-AUC improvement for {args.patience} epochs)") f"(no val-AUC improvement for {args.patience} epochs)")
break break
# ---- Restore best checkpoint and compute test metrics ---- # ---- Restore best and test ----
model.load_state_dict(best_state) model.load_state_dict(best_state)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
z_final = model.encode(x, data.edge_index) z_final = model.encode(x, data.edge_index, full_ew)
test_auc, test_ap = _eval_linkpred( test_auc, test_ap = _eval_linkpred(
z_final, test_data.pos_edge_index, test_data.neg_edge_index z_final, test_data.pos_edge_index, test_data.neg_edge_index
) )
# ---- Save outputs ---- # ---- Save outputs ----
model_path = os.path.join(args.outdir, "model.pt") torch.save(best_state, os.path.join(args.outdir, "model.pt"))
torch.save(best_state, model_path) np.save(os.path.join(args.outdir, "emb.npy"), z_final.cpu().numpy())
emb_path = os.path.join(args.outdir, "emb.npy")
np.save(emb_path, z_final.cpu().numpy())
metrics = { metrics = {
"encoder": args.encoder,
"val_auc": float(best_val_auc), "val_auc": float(best_val_auc),
"test_auc": float(test_auc), "test_auc": float(test_auc),
"test_ap": float(test_ap), "test_ap": float(test_ap),
"epochs_ran": epochs_ran, "epochs_ran": epochs_ran,
"epochs_max": args.epochs, "epochs_max": args.epochs,
"patience": args.patience, "patience": args.patience,
"beta": args.beta,
"kl_anneal": args.kl_anneal,
"hidden": args.hidden, "hidden": args.hidden,
"latent": args.latent, "latent": args.latent,
"heads": args.heads if args.encoder == "gat" else None,
"dropout": args.dropout, "dropout": args.dropout,
"lr": args.lr, "lr": args.lr,
"seed": args.seed, "seed": args.seed,
@@ -177,9 +222,10 @@ def main():
with open(os.path.join(args.outdir, "metrics.json"), "w") as f: with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
json.dump(metrics, f, indent=2) json.dump(metrics, f, indent=2)
print(f"\nSaved model → {model_path}") print(f"\nSaved {args.outdir}/")
print(f"Saved embeddings {emb_path} shape={z_final.shape}") print(f"Embeddings shape: {z_final.shape}")
print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f}") print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f} "
f"(val best={best_val_auc:.4f})")
if __name__ == "__main__": if __name__ == "__main__":