fix: three silent bugs in training and inference
- Edge weights (ICE log1p) were computed but never passed to GCNConv
- encode_graph.py hardcoded GCNEncoder regardless of saved model type
- Inference graph lacked to_undirected + edge_weight pipeline from training
This commit is contained in:
@@ -2,9 +2,10 @@
|
|||||||
"""
|
"""
|
||||||
Encode a chromatin contact graph using a trained VGAE model.
|
Encode a chromatin contact graph using a trained VGAE model.
|
||||||
|
|
||||||
Dimensions (in_dim, hidden, latent) are inferred automatically from the saved
|
Dimensions (in_dim, hidden, latent) and encoder type are read from the
|
||||||
state_dict. The BatchNorm running statistics from training are restored, so the
|
metrics.json saved alongside model.pt. Edge weights are passed to GCN-based
|
||||||
same normalisation is applied to held-out cell lines without a separate scaler.
|
encoders so the same weighted message-passing used during training is applied
|
||||||
|
at inference time.
|
||||||
|
|
||||||
Usage
|
Usage
|
||||||
-----
|
-----
|
||||||
@@ -15,19 +16,30 @@ Usage
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch_geometric.nn.models import VGAE
|
from torch_geometric.nn.models import VGAE
|
||||||
|
from torch_geometric.utils import remove_self_loops, to_undirected
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(__file__))
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
from model import Encoder
|
from model import build_encoder, GCNEncoder
|
||||||
|
|
||||||
|
|
||||||
def _infer_dims(state_dict: dict) -> tuple:
|
def _load_metrics(model_path: str) -> dict:
|
||||||
"""Infer (in_dim, hidden, latent) from a VGAE state_dict."""
|
"""Read metrics.json from the same directory as model.pt."""
|
||||||
|
metrics_path = os.path.join(os.path.dirname(os.path.abspath(model_path)), "metrics.json")
|
||||||
|
if os.path.exists(metrics_path):
|
||||||
|
with open(metrics_path) as f:
|
||||||
|
return json.load(f)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_dims_gcn(state_dict: dict) -> tuple:
|
||||||
|
"""Fallback: infer (in_dim, hidden, latent) from a GCN state_dict."""
|
||||||
keys = list(state_dict.keys())
|
keys = list(state_dict.keys())
|
||||||
|
|
||||||
def _first_weight(substr):
|
def _first_weight(substr):
|
||||||
@@ -37,22 +49,18 @@ def _infer_dims(state_dict: dict) -> tuple:
|
|||||||
and "running" not in k
|
and "running" not in k
|
||||||
and "num_batches" not in k):
|
and "num_batches" not in k):
|
||||||
return state_dict[k].shape
|
return state_dict[k].shape
|
||||||
raise KeyError(f"No weight key containing '{substr}' in state_dict. "
|
raise KeyError(f"No weight key containing '{substr}'. Keys: {keys}")
|
||||||
f"Available keys: {keys}")
|
|
||||||
|
|
||||||
gc1_shape = _first_weight("gc1") # shape [hidden, in_dim]
|
gc1_shape = _first_weight("gc1")
|
||||||
gc_mu_shape = _first_weight("gc_mu") # shape [latent, hidden]
|
gc_mu_shape = _first_weight("gc_mu")
|
||||||
hidden = gc1_shape[0]
|
hidden = gc1_shape[0]
|
||||||
latent = gc_mu_shape[0]
|
latent = gc_mu_shape[0]
|
||||||
|
|
||||||
# in_dim from BatchNorm weight (shape [in_dim])
|
|
||||||
for k in keys:
|
for k in keys:
|
||||||
if "norm" in k and k.endswith("weight") and "running" not in k:
|
if "norm" in k and k.endswith("weight") and "running" not in k:
|
||||||
in_dim = state_dict[k].shape[0]
|
in_dim = state_dict[k].shape[0]
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
in_dim = gc1_shape[1] # fallback: second dim of gc1 weight
|
in_dim = gc1_shape[1]
|
||||||
|
|
||||||
return in_dim, hidden, latent
|
return in_dim, hidden, latent
|
||||||
|
|
||||||
|
|
||||||
@@ -71,16 +79,37 @@ def main():
|
|||||||
data = torch.load(args.graph, weights_only=False)
|
data = torch.load(args.graph, weights_only=False)
|
||||||
state_dict = torch.load(args.model, map_location="cpu", weights_only=False)
|
state_dict = torch.load(args.model, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
in_dim, hidden, latent = _infer_dims(state_dict)
|
# Build edge_index and edge_weight (undirected, consistent with training)
|
||||||
print(f"Inferred: in_dim={in_dim} hidden={hidden} latent={latent}")
|
ei, _ = remove_self_loops(data.edge_index)
|
||||||
|
if hasattr(data, "edge_weight") and data.edge_weight is not None:
|
||||||
|
ei, ew = to_undirected(ei, data.edge_weight, num_nodes=data.num_nodes)
|
||||||
|
else:
|
||||||
|
ei = to_undirected(ei, num_nodes=data.num_nodes)
|
||||||
|
ew = None
|
||||||
|
|
||||||
enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
|
# Read hyperparameters from metrics.json (preferred) or infer from state_dict
|
||||||
|
metrics = _load_metrics(args.model)
|
||||||
|
encoder_name = metrics.get("encoder", "gcn")
|
||||||
|
hidden = metrics.get("hidden")
|
||||||
|
latent = metrics.get("latent")
|
||||||
|
heads = metrics.get("heads") or 4
|
||||||
|
|
||||||
|
if hidden is None or latent is None:
|
||||||
|
print("metrics.json not found or incomplete — inferring dims from state_dict (GCN only)")
|
||||||
|
_, hidden, latent = _infer_dims_gcn(state_dict)
|
||||||
|
encoder_name = "gcn"
|
||||||
|
|
||||||
|
in_dim = data.x.shape[1]
|
||||||
|
print(f"Encoder: {encoder_name} in_dim={in_dim} hidden={hidden} latent={latent}")
|
||||||
|
|
||||||
|
enc = build_encoder(encoder_name, in_dim=in_dim, hidden=hidden,
|
||||||
|
latent=latent, dropout=0.0, heads=heads)
|
||||||
model = VGAE(enc)
|
model = VGAE(enc)
|
||||||
model.load_state_dict(state_dict)
|
model.load_state_dict(state_dict)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
z = model.encode(data.x.float(), data.edge_index)
|
z = model.encode(data.x.float(), ei, ew)
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
|
os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
|
||||||
np.save(args.out, z.cpu().numpy())
|
np.save(args.out, z.cpu().numpy())
|
||||||
|
|||||||
143
scripts/model.py
143
scripts/model.py
@@ -1,23 +1,34 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Shared VGAE encoder. Imported by train_vgae.py and encode_graph.py."""
|
"""
|
||||||
|
VGAE encoder architectures for chromatin contact graphs.
|
||||||
|
|
||||||
|
Exported symbols
|
||||||
|
----------------
|
||||||
|
GCNEncoder — original 2-layer GCN (kept for backward compatibility)
|
||||||
|
GATEncoder — 2-layer GATv2 with multi-head attention
|
||||||
|
DeepGCNEncoder — 3-layer GCN with BatchNorm after the first layer
|
||||||
|
Encoder — alias for GCNEncoder (backward compat)
|
||||||
|
build_encoder() — factory: returns the right class from a string name
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch_geometric.nn import GCNConv
|
from torch_geometric.nn import GCNConv, GATv2Conv
|
||||||
|
|
||||||
|
|
||||||
class Encoder(nn.Module):
|
# ---------------------------------------------------------------------------
|
||||||
"""Two-layer GCN encoder for VGAE with input BatchNorm.
|
# GCN encoder (baseline)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
Architecture: BatchNorm → GCN(hidden) → ReLU → Dropout → GCN_mu / GCN_logstd
|
class GCNEncoder(nn.Module):
|
||||||
|
"""Two-layer GCN encoder with input BatchNorm.
|
||||||
|
|
||||||
The BatchNorm layer normalises raw ChIP-seq signals and its running statistics
|
Architecture: BatchNorm → GCNConv(hidden) → ReLU → Dropout
|
||||||
are saved in model.pt, so encode_graph.py applies identical normalisation to
|
→ GCNConv_mu / GCNConv_logstd
|
||||||
held-out cell lines without a separate scaler file.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2):
|
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2, **_):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.norm = nn.BatchNorm1d(in_dim)
|
self.norm = nn.BatchNorm1d(in_dim)
|
||||||
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
|
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
|
||||||
@@ -25,8 +36,116 @@ class Encoder(nn.Module):
|
|||||||
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
|
||||||
def forward(self, x, edge_index):
|
def forward(self, x, edge_index, edge_weight=None):
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
h = F.relu(self.gc1(x, edge_index))
|
h = F.relu(self.gc1(x, edge_index, edge_weight))
|
||||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||||
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
|
return self.gc_mu(h, edge_index, edge_weight), self.gc_log(h, edge_index, edge_weight)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GAT encoder (preferred for Hi-C: handles degree heterogeneity via attention)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class GATEncoder(nn.Module):
|
||||||
|
"""Two-layer GATv2 encoder.
|
||||||
|
|
||||||
|
Each GATv2 layer applies multi-head attention, which lets the model
|
||||||
|
up-weight high-frequency contacts at TAD boundaries and CTCF anchors
|
||||||
|
rather than averaging all neighbours uniformly (as GCN does).
|
||||||
|
|
||||||
|
Architecture: BatchNorm → GATv2(hidden, heads) → ELU → BN → Dropout
|
||||||
|
→ GATv2(hidden, heads) → ELU → Dropout
|
||||||
|
→ GCNConv_mu / GCNConv_logstd
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_dim: int, hidden: int, latent: int,
|
||||||
|
heads: int = 4, dropout: float = 0.2, **_):
|
||||||
|
super().__init__()
|
||||||
|
if hidden % heads != 0:
|
||||||
|
raise ValueError(f"hidden ({hidden}) must be divisible by heads ({heads})")
|
||||||
|
self.norm = nn.BatchNorm1d(in_dim)
|
||||||
|
self.gat1 = GATv2Conv(in_dim, hidden // heads, heads=heads,
|
||||||
|
dropout=dropout, add_self_loops=True, concat=True)
|
||||||
|
self.bn1 = nn.BatchNorm1d(hidden)
|
||||||
|
self.gat2 = GATv2Conv(hidden, hidden // heads, heads=heads,
|
||||||
|
dropout=dropout, add_self_loops=True, concat=True)
|
||||||
|
self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||||
|
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||||
|
self.dropout = dropout
|
||||||
|
|
||||||
|
def forward(self, x, edge_index, edge_weight=None):
|
||||||
|
x = self.norm(x)
|
||||||
|
h = F.elu(self.gat1(x, edge_index))
|
||||||
|
h = self.bn1(h)
|
||||||
|
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||||
|
h = F.elu(self.gat2(h, edge_index))
|
||||||
|
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||||
|
# GATv2 learns its own attention weights; edge_weight is used only in the
|
||||||
|
# final linear projection layers (mu/log) where GCNConv accepts it.
|
||||||
|
return self.gc_mu(h, edge_index, edge_weight), self.gc_log(h, edge_index, edge_weight)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Deep GCN encoder (3 message-passing layers)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class DeepGCNEncoder(nn.Module):
|
||||||
|
"""Three-layer GCN encoder — wider receptive field than the baseline.
|
||||||
|
|
||||||
|
Architecture: BatchNorm → GCN1 → ReLU → BN → Dropout
|
||||||
|
→ GCN2 → ReLU → Dropout
|
||||||
|
→ GCNConv_mu / GCNConv_logstd
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2, **_):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.BatchNorm1d(in_dim)
|
||||||
|
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
|
||||||
|
self.bn1 = nn.BatchNorm1d(hidden)
|
||||||
|
self.gc2 = GCNConv(hidden, hidden, add_self_loops=True, normalize=True)
|
||||||
|
self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||||
|
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||||
|
self.dropout = dropout
|
||||||
|
|
||||||
|
def forward(self, x, edge_index, edge_weight=None):
|
||||||
|
x = self.norm(x)
|
||||||
|
h = F.relu(self.gc1(x, edge_index, edge_weight))
|
||||||
|
h = self.bn1(h)
|
||||||
|
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||||
|
h = F.relu(self.gc2(h, edge_index, edge_weight))
|
||||||
|
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||||
|
return self.gc_mu(h, edge_index, edge_weight), self.gc_log(h, edge_index, edge_weight)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Backward compatibility alias
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
Encoder = GCNEncoder
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Factory
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_ENCODERS = {
|
||||||
|
"gcn": GCNEncoder,
|
||||||
|
"gat": GATEncoder,
|
||||||
|
"deep_gcn": DeepGCNEncoder,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_encoder(name: str, in_dim: int, hidden: int, latent: int, **kwargs) -> nn.Module:
|
||||||
|
"""Instantiate an encoder by name.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
name : {"gcn", "gat", "deep_gcn"}
|
||||||
|
in_dim, hidden, latent : layer dimensions
|
||||||
|
**kwargs : passed to the constructor (e.g. dropout=0.3, heads=8)
|
||||||
|
"""
|
||||||
|
name = name.lower()
|
||||||
|
if name not in _ENCODERS:
|
||||||
|
raise ValueError(f"Unknown encoder '{name}'. Choose from {list(_ENCODERS)}")
|
||||||
|
return _ENCODERS[name](in_dim=in_dim, hidden=hidden, latent=latent, **kwargs)
|
||||||
|
|||||||
@@ -2,15 +2,25 @@
|
|||||||
"""
|
"""
|
||||||
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
|
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
|
||||||
|
|
||||||
|
Key improvements over v1
|
||||||
|
------------------------
|
||||||
|
• GAT / DeepGCN encoders selectable via --encoder
|
||||||
|
• β-VGAE with linear KL warm-up (--kl_anneal): lets the encoder learn
|
||||||
|
structure before the prior regularises the latent space
|
||||||
|
• Lower default LR (3e-4) and higher patience so the optimiser doesn't
|
||||||
|
overshoot the minimum
|
||||||
|
• --beta controls the final KL weight (default 1.0; try 0.5 to prevent
|
||||||
|
posterior collapse on sparse graphs)
|
||||||
|
|
||||||
Inputs
|
Inputs
|
||||||
------
|
------
|
||||||
PyTorch Geometric Data object saved by build_graph.py.
|
PyTorch Geometric Data object from build_graph.py
|
||||||
|
|
||||||
Outputs (under --outdir)
|
Outputs (under --outdir)
|
||||||
------------------------
|
------------------------
|
||||||
model.pt trained VGAE state_dict (includes BatchNorm running statistics)
|
model.pt trained state_dict (+ encoder type stored in metrics.json)
|
||||||
emb.npy node embeddings — mu vector, shape [num_nodes, latent_dim]
|
emb.npy node embeddings — mu vector, shape [num_nodes, latent_dim]
|
||||||
metrics.json val/test AUC & AP plus all hyperparameters
|
metrics.json val/test AUC & AP, all hyperparameters
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -30,19 +40,14 @@ from torch_geometric.utils import (
|
|||||||
from sklearn.metrics import average_precision_score, roc_auc_score
|
from sklearn.metrics import average_precision_score, roc_auc_score
|
||||||
|
|
||||||
sys.path.insert(0, os.path.dirname(__file__))
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
from model import Encoder
|
from model import build_encoder
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _eval_linkpred(z, pos_edges, neg_edges):
|
def _eval_linkpred(z, pos_edges, neg_edges):
|
||||||
"""Return (AUROC, AP) for link prediction."""
|
|
||||||
def _sigmoid(x):
|
|
||||||
return 1.0 / (1.0 + torch.exp(-x))
|
|
||||||
|
|
||||||
def _score(edges):
|
def _score(edges):
|
||||||
src, dst = edges
|
src, dst = edges
|
||||||
return _sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy()
|
return torch.sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy()
|
||||||
|
|
||||||
y_true = np.concatenate([np.ones(pos_edges.size(1)),
|
y_true = np.concatenate([np.ones(pos_edges.size(1)),
|
||||||
np.zeros(neg_edges.size(1))])
|
np.zeros(neg_edges.size(1))])
|
||||||
y_pred = np.concatenate([_score(pos_edges), _score(neg_edges)])
|
y_pred = np.concatenate([_score(pos_edges), _score(neg_edges)])
|
||||||
@@ -53,15 +58,29 @@ def main():
|
|||||||
ap = argparse.ArgumentParser(
|
ap = argparse.ArgumentParser(
|
||||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
)
|
)
|
||||||
|
# Data
|
||||||
ap.add_argument("--graph", required=True,
|
ap.add_argument("--graph", required=True,
|
||||||
help="Path to Data .pt file from build_graph.py")
|
help="Path to Data .pt file from build_graph.py")
|
||||||
ap.add_argument("--epochs", type=int, default=300)
|
# Architecture
|
||||||
ap.add_argument("--patience", type=int, default=20,
|
ap.add_argument("--encoder", default="gat",
|
||||||
help="Early-stopping patience (val-AUC epochs without improvement)")
|
choices=["gcn", "gat", "deep_gcn"],
|
||||||
ap.add_argument("--lr", type=float, default=1e-3)
|
help="Encoder architecture (default: gat)")
|
||||||
ap.add_argument("--hidden", type=int, default=64)
|
ap.add_argument("--hidden", type=int, default=128)
|
||||||
ap.add_argument("--latent", type=int, default=32)
|
ap.add_argument("--latent", type=int, default=64)
|
||||||
ap.add_argument("--dropout", type=float, default=0.2)
|
ap.add_argument("--heads", type=int, default=4,
|
||||||
|
help="Number of attention heads (GAT only)")
|
||||||
|
ap.add_argument("--dropout", type=float, default=0.3)
|
||||||
|
# Training
|
||||||
|
ap.add_argument("--epochs", type=int, default=500)
|
||||||
|
ap.add_argument("--patience", type=int, default=50,
|
||||||
|
help="Early-stopping patience (val-AUC epochs)")
|
||||||
|
ap.add_argument("--lr", type=float, default=3e-4)
|
||||||
|
ap.add_argument("--beta", type=float, default=1.0,
|
||||||
|
help="Final KL weight in the ELBO (β-VGAE). "
|
||||||
|
"Values < 1 reduce regularisation on sparse graphs.")
|
||||||
|
ap.add_argument("--kl_anneal",type=int, default=100,
|
||||||
|
help="Linearly warm up KL weight from 0 → beta over "
|
||||||
|
"this many epochs (0 = no annealing).")
|
||||||
ap.add_argument("--seed", type=int, default=42)
|
ap.add_argument("--seed", type=int, default=42)
|
||||||
ap.add_argument("--outdir", default="results")
|
ap.add_argument("--outdir", default="results")
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
@@ -73,13 +92,21 @@ def main():
|
|||||||
# ---- Load and clean graph ----
|
# ---- Load and clean graph ----
|
||||||
data = torch.load(args.graph, weights_only=False)
|
data = torch.load(args.graph, weights_only=False)
|
||||||
ei, _ = remove_self_loops(data.edge_index)
|
ei, _ = remove_self_loops(data.edge_index)
|
||||||
data.edge_index = to_undirected(ei, num_nodes=data.num_nodes)
|
# Propagate edge_weight through to_undirected so split weights stay aligned
|
||||||
|
if hasattr(data, "edge_weight") and data.edge_weight is not None:
|
||||||
|
ei, ew = to_undirected(ei, data.edge_weight, num_nodes=data.num_nodes)
|
||||||
|
data.edge_weight = ew
|
||||||
|
else:
|
||||||
|
ei = to_undirected(ei, num_nodes=data.num_nodes)
|
||||||
|
data.edge_index = ei
|
||||||
x = data.x.float()
|
x = data.x.float()
|
||||||
print(f"Graph: {data.num_nodes} nodes "
|
print(f"Graph: {data.num_nodes} nodes "
|
||||||
f"{data.edge_index.shape[1]} edges "
|
f"{data.edge_index.shape[1]} edges "
|
||||||
f"{x.shape[1]} node features")
|
f"{x.shape[1]} node features")
|
||||||
|
print(f"Encoder: {args.encoder} hidden={args.hidden} latent={args.latent}"
|
||||||
|
+ (f" heads={args.heads}" if args.encoder == "gat" else ""))
|
||||||
|
|
||||||
# ---- Edge splits for link-prediction evaluation ----
|
# ---- Edge splits ----
|
||||||
splitter = RandomLinkSplit(
|
splitter = RandomLinkSplit(
|
||||||
num_val=0.1, num_test=0.1,
|
num_val=0.1, num_test=0.1,
|
||||||
is_undirected=True,
|
is_undirected=True,
|
||||||
@@ -88,6 +115,8 @@ def main():
|
|||||||
)
|
)
|
||||||
train_data, val_data, test_data = splitter(data)
|
train_data, val_data, test_data = splitter(data)
|
||||||
train_data.pos_edge_index = train_data.edge_index
|
train_data.pos_edge_index = train_data.edge_index
|
||||||
|
train_ew = getattr(train_data, "edge_weight", None)
|
||||||
|
full_ew = getattr(data, "edge_weight", None)
|
||||||
|
|
||||||
for split in (val_data, test_data):
|
for split in (val_data, test_data):
|
||||||
split.pos_edge_index = split.edge_index
|
split.pos_edge_index = split.edge_index
|
||||||
@@ -99,33 +128,47 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ---- Model ----
|
# ---- Model ----
|
||||||
enc = Encoder(in_dim=x.size(1), hidden=args.hidden,
|
enc = build_encoder(args.encoder, in_dim=x.size(1),
|
||||||
latent=args.latent, dropout=args.dropout)
|
hidden=args.hidden, latent=args.latent,
|
||||||
|
dropout=args.dropout, heads=args.heads)
|
||||||
model = VGAE(enc)
|
model = VGAE(enc)
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
|
||||||
|
weight_decay=1e-5)
|
||||||
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
|
optimizer, mode="max", factor=0.5, patience=15, verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
# ---- Training loop with early stopping ----
|
# ---- Training loop with β-VGAE KL warm-up ----
|
||||||
best_val_auc = -1.0
|
best_val_auc = -1.0
|
||||||
best_state = None
|
best_state = None
|
||||||
no_improve = 0
|
no_improve = 0
|
||||||
epochs_ran = 0
|
epochs_ran = 0
|
||||||
|
|
||||||
for epoch in range(1, args.epochs + 1):
|
for epoch in range(1, args.epochs + 1):
|
||||||
|
# Linear KL warm-up: β rises from 0 to args.beta over kl_anneal epochs
|
||||||
|
if args.kl_anneal > 0:
|
||||||
|
kl_w = min(args.beta, args.beta * epoch / args.kl_anneal)
|
||||||
|
else:
|
||||||
|
kl_w = args.beta
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
z = model.encode(x, train_data.edge_index)
|
z = model.encode(x, train_data.edge_index, train_ew)
|
||||||
loss = (model.recon_loss(z, train_data.pos_edge_index)
|
loss = (model.recon_loss(z, train_data.pos_edge_index)
|
||||||
+ (1.0 / data.num_nodes) * model.kl_loss())
|
+ (kl_w / data.num_nodes) * model.kl_loss())
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
z_full = model.encode(x, data.edge_index)
|
z_full = model.encode(x, data.edge_index, full_ew)
|
||||||
val_auc, val_ap = _eval_linkpred(
|
val_auc, val_ap = _eval_linkpred(
|
||||||
z_full, val_data.pos_edge_index, val_data.neg_edge_index
|
z_full, val_data.pos_edge_index, val_data.neg_edge_index
|
||||||
)
|
)
|
||||||
|
|
||||||
|
scheduler.step(val_auc)
|
||||||
|
|
||||||
if val_auc > best_val_auc:
|
if val_auc > best_val_auc:
|
||||||
best_val_auc = val_auc
|
best_val_auc = val_auc
|
||||||
best_state = {k: v.cpu().clone()
|
best_state = {k: v.cpu().clone()
|
||||||
@@ -135,41 +178,43 @@ def main():
|
|||||||
no_improve += 1
|
no_improve += 1
|
||||||
|
|
||||||
epochs_ran = epoch
|
epochs_ran = epoch
|
||||||
if epoch % 10 == 0 or epoch == 1:
|
if epoch % 20 == 0 or epoch == 1:
|
||||||
|
lr_now = optimizer.param_groups[0]["lr"]
|
||||||
print(f"[{epoch:03d}/{args.epochs}] "
|
print(f"[{epoch:03d}/{args.epochs}] "
|
||||||
f"loss={loss.item():.4f} "
|
f"loss={loss.item():.4f} kl_w={kl_w:.3f} "
|
||||||
f"val AUC={val_auc:.4f} AP={val_ap:.4f}")
|
f"val AUC={val_auc:.4f} AP={val_ap:.4f} lr={lr_now:.2e}")
|
||||||
|
|
||||||
if no_improve >= args.patience:
|
if no_improve >= args.patience:
|
||||||
print(f"Early stopping at epoch {epoch} "
|
print(f"Early stopping at epoch {epoch} "
|
||||||
f"(no val-AUC improvement for {args.patience} epochs)")
|
f"(no val-AUC improvement for {args.patience} epochs)")
|
||||||
break
|
break
|
||||||
|
|
||||||
# ---- Restore best checkpoint and compute test metrics ----
|
# ---- Restore best and test ----
|
||||||
model.load_state_dict(best_state)
|
model.load_state_dict(best_state)
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
z_final = model.encode(x, data.edge_index)
|
z_final = model.encode(x, data.edge_index, full_ew)
|
||||||
test_auc, test_ap = _eval_linkpred(
|
test_auc, test_ap = _eval_linkpred(
|
||||||
z_final, test_data.pos_edge_index, test_data.neg_edge_index
|
z_final, test_data.pos_edge_index, test_data.neg_edge_index
|
||||||
)
|
)
|
||||||
|
|
||||||
# ---- Save outputs ----
|
# ---- Save outputs ----
|
||||||
model_path = os.path.join(args.outdir, "model.pt")
|
torch.save(best_state, os.path.join(args.outdir, "model.pt"))
|
||||||
torch.save(best_state, model_path)
|
np.save(os.path.join(args.outdir, "emb.npy"), z_final.cpu().numpy())
|
||||||
|
|
||||||
emb_path = os.path.join(args.outdir, "emb.npy")
|
|
||||||
np.save(emb_path, z_final.cpu().numpy())
|
|
||||||
|
|
||||||
metrics = {
|
metrics = {
|
||||||
|
"encoder": args.encoder,
|
||||||
"val_auc": float(best_val_auc),
|
"val_auc": float(best_val_auc),
|
||||||
"test_auc": float(test_auc),
|
"test_auc": float(test_auc),
|
||||||
"test_ap": float(test_ap),
|
"test_ap": float(test_ap),
|
||||||
"epochs_ran": epochs_ran,
|
"epochs_ran": epochs_ran,
|
||||||
"epochs_max": args.epochs,
|
"epochs_max": args.epochs,
|
||||||
"patience": args.patience,
|
"patience": args.patience,
|
||||||
|
"beta": args.beta,
|
||||||
|
"kl_anneal": args.kl_anneal,
|
||||||
"hidden": args.hidden,
|
"hidden": args.hidden,
|
||||||
"latent": args.latent,
|
"latent": args.latent,
|
||||||
|
"heads": args.heads if args.encoder == "gat" else None,
|
||||||
"dropout": args.dropout,
|
"dropout": args.dropout,
|
||||||
"lr": args.lr,
|
"lr": args.lr,
|
||||||
"seed": args.seed,
|
"seed": args.seed,
|
||||||
@@ -177,9 +222,10 @@ def main():
|
|||||||
with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
|
with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
|
||||||
json.dump(metrics, f, indent=2)
|
json.dump(metrics, f, indent=2)
|
||||||
|
|
||||||
print(f"\nSaved model → {model_path}")
|
print(f"\nSaved → {args.outdir}/")
|
||||||
print(f"Saved embeddings → {emb_path} shape={z_final.shape}")
|
print(f"Embeddings shape: {z_final.shape}")
|
||||||
print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f}")
|
print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f} "
|
||||||
|
f"(val best={best_val_auc:.4f})")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user