Files
chromatin-vgae-hic/scripts/encode_graph.py
aman a081f29e12 fix: three silent bugs in training and inference
- Edge weights (ICE log1p) were computed but never passed to GCNConv
- encode_graph.py hardcoded GCNEncoder regardless of saved model type
- Inference graph lacked to_undirected + edge_weight pipeline from training
2026-05-15 03:37:20 +02:00

121 lines
4.0 KiB
Python

#!/usr/bin/env python3
"""
Encode a chromatin contact graph using a trained VGAE model.
Dimensions (in_dim, hidden, latent) and encoder type are read from the
metrics.json saved alongside model.pt. Edge weights are passed to GCN-based
encoders so the same weighted message-passing used during training is applied
at inference time.
Usage
-----
python scripts/encode_graph.py \\
--model results/GM12878/model.pt \\
--graph data/processed/IMR90_chr21.pt \\
--out results/IMR90/emb.npy
"""
import argparse
import json
import os
import sys
import numpy as np
import torch
from torch_geometric.nn.models import VGAE
from torch_geometric.utils import remove_self_loops, to_undirected
sys.path.insert(0, os.path.dirname(__file__))
from model import build_encoder, GCNEncoder
def _load_metrics(model_path: str) -> dict:
"""Read metrics.json from the same directory as model.pt."""
metrics_path = os.path.join(os.path.dirname(os.path.abspath(model_path)), "metrics.json")
if os.path.exists(metrics_path):
with open(metrics_path) as f:
return json.load(f)
return {}
def _infer_dims_gcn(state_dict: dict) -> tuple:
"""Fallback: infer (in_dim, hidden, latent) from a GCN state_dict."""
keys = list(state_dict.keys())
def _first_weight(substr):
for k in keys:
if (substr in k
and "weight" in k
and "running" not in k
and "num_batches" not in k):
return state_dict[k].shape
raise KeyError(f"No weight key containing '{substr}'. Keys: {keys}")
gc1_shape = _first_weight("gc1")
gc_mu_shape = _first_weight("gc_mu")
hidden = gc1_shape[0]
latent = gc_mu_shape[0]
for k in keys:
if "norm" in k and k.endswith("weight") and "running" not in k:
in_dim = state_dict[k].shape[0]
break
else:
in_dim = gc1_shape[1]
return in_dim, hidden, latent
def _build_undirected_edges(data):
    """Return ``(edge_index, edge_weight)`` without self-loops, symmetrised.

    ``edge_weight`` is ``None`` when *data* carries no edge weights.
    """
    raw_weight = getattr(data, "edge_weight", None)
    # Filter the weights together with the indices: dropping self-loops from
    # edge_index alone would leave the weight tensor longer than the edge list
    # and misalign (or break) the to_undirected call below.
    ei, ew = remove_self_loops(data.edge_index, raw_weight)
    if ew is None:
        return to_undirected(ei, num_nodes=data.num_nodes), None
    ei, ew = to_undirected(ei, ew, num_nodes=data.num_nodes)
    # Match the float32 node features that are passed to the encoder.
    return ei, ew.float()


def main():
    """Command-line entry point: encode a graph with a trained VGAE.

    Loads the graph and the model state_dict, rebuilds the encoder from the
    hyperparameters in ``metrics.json`` (falling back to shapes inferred from
    a GCN state_dict), runs the encoder without gradients and saves the node
    embeddings to the ``--out`` path as a ``.npy`` file.
    """
    p = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    p.add_argument("--model", required=True,
                   help="Path to model.pt saved by train_vgae.py")
    p.add_argument("--graph", required=True,
                   help="Path to Data .pt file from build_graph.py")
    p.add_argument("--out", required=True,
                   help="Output .npy path for node embeddings")
    args = p.parse_args()

    # NOTE(review): weights_only=False unpickles arbitrary objects; only use
    # this script on files from a trusted source.
    # Load both artefacts onto the CPU so a graph saved from a GPU run does
    # not end up on a different device than the model.
    data = torch.load(args.graph, map_location="cpu", weights_only=False)
    state_dict = torch.load(args.model, map_location="cpu", weights_only=False)

    # Build edge_index and edge_weight (undirected, consistent with training)
    ei, ew = _build_undirected_edges(data)

    # Read hyperparameters from metrics.json (preferred) or infer from state_dict
    metrics = _load_metrics(args.model)
    encoder_name = metrics.get("encoder", "gcn")
    hidden = metrics.get("hidden")
    latent = metrics.get("latent")
    heads = metrics.get("heads") or 4
    if hidden is None or latent is None:
        print("metrics.json not found or incomplete — inferring dims from state_dict (GCN only)")
        _, hidden, latent = _infer_dims_gcn(state_dict)
        encoder_name = "gcn"

    # The input width always comes from the graph being encoded.
    in_dim = data.x.shape[1]
    print(f"Encoder: {encoder_name} in_dim={in_dim} hidden={hidden} latent={latent}")

    # dropout=0.0: inference only, no stochastic regularisation.
    enc = build_encoder(encoder_name, in_dim=in_dim, hidden=hidden,
                        latent=latent, dropout=0.0, heads=heads)
    model = VGAE(enc)
    model.load_state_dict(state_dict)
    model.eval()
    with torch.no_grad():
        z = model.encode(data.x.float(), ei, ew)

    os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
    np.save(args.out, z.cpu().numpy())
    print(f"Saved embeddings → {args.out} shape={z.shape}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()