#!/usr/bin/env python3
"""
Encode a chromatin contact graph using a trained VGAE model.

Hidden/latent dimensions and the encoder type are read from the metrics.json
saved alongside model.pt (falling back to the checkpoint itself for GCN
models); the input dimension is taken from the graph's node-feature matrix.
Edge weights, when the graph provides them, are passed to the encoder so the
same weighted message-passing used during training is applied at inference
time.

Usage
-----
python scripts/encode_graph.py \\
    --model results/GM12878/model.pt \\
    --graph data/processed/IMR90_chr21.pt \\
    --out results/IMR90/emb.npy
"""

import argparse
import json
import os
import sys

import numpy as np
import torch
from torch_geometric.nn.models import VGAE
from torch_geometric.utils import remove_self_loops, to_undirected

from chromatin_gnn.model import build_encoder, GCNEncoder


def _load_metrics(model_path: str) -> dict:
    """Read metrics.json from the same directory as model.pt.

    Args:
        model_path: Path to the model checkpoint; only its directory is used.

    Returns:
        The parsed JSON content, or an empty dict when no metrics.json sits
        next to the checkpoint.

    Raises:
        json.JSONDecodeError: If metrics.json exists but is not valid JSON.
    """
    metrics_path = os.path.join(
        os.path.dirname(os.path.abspath(model_path)), "metrics.json"
    )
    try:
        with open(metrics_path, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # A missing file is expected; main() falls back to the state_dict.
        return {}


def _infer_dims_gcn(state_dict: dict) -> tuple:
    """Fallback: infer (in_dim, hidden, latent) from a GCN state_dict.

    ``hidden`` and ``latent`` are the leading dimension of the first weight
    tensor whose key contains ``"gc1"`` and ``"gc_mu"`` respectively.
    ``in_dim`` is the length of the first ``norm`` weight if such a key
    exists, otherwise the second dimension of the ``gc1`` weight.

    Args:
        state_dict: Mapping of parameter names to tensors.

    Returns:
        Tuple ``(in_dim, hidden, latent)``.

    Raises:
        KeyError: If no weight key for ``gc1`` or ``gc_mu`` is found.
    """
    keys = list(state_dict.keys())

    def _first_weight(substr):
        # Skip normalisation-layer buffers that also carry the layer name.
        for k in keys:
            if (substr in k and "weight" in k
                    and "running" not in k and "num_batches" not in k):
                return state_dict[k].shape
        raise KeyError(f"No weight key containing '{substr}'. Keys: {keys}")

    gc1_shape = _first_weight("gc1")
    gc_mu_shape = _first_weight("gc_mu")
    hidden = gc1_shape[0]
    latent = gc_mu_shape[0]

    # NOTE(review): presumably the first "norm" weight belongs to an
    # input-feature normalisation layer -- confirm against the model code.
    for k in keys:
        if "norm" in k and k.endswith("weight") and "running" not in k:
            in_dim = state_dict[k].shape[0]
            break
    else:
        in_dim = gc1_shape[1]

    return in_dim, hidden, latent


def _undirected_edges(data):
    """Return ``(edge_index, edge_weight)`` without self-loops, symmetrised.

    Self-loops are removed from the edge index and the edge weights in the
    same call so both stay aligned; filtering only the index would leave the
    weight vector longer than the edge list whenever a self-loop is present.

    Args:
        data: Graph object exposing ``edge_index``, ``num_nodes`` and,
            optionally, ``edge_weight``.

    Returns:
        Tuple ``(edge_index, edge_weight)``; ``edge_weight`` is ``None`` when
        the graph carries no weights.
    """
    edge_weight = getattr(data, "edge_weight", None)
    edge_index, edge_weight = remove_self_loops(data.edge_index, edge_weight)
    if edge_weight is None:
        return to_undirected(edge_index, num_nodes=data.num_nodes), None
    return to_undirected(edge_index, edge_weight, num_nodes=data.num_nodes)


def main():
    """Parse CLI arguments, encode the graph and save embeddings as .npy."""
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    p.add_argument("--model", required=True,
                   help="Path to model.pt saved by train_vgae.py")
    p.add_argument("--graph", required=True,
                   help="Path to Data .pt file from build_graph.py")
    p.add_argument("--out", required=True,
                   help="Output .npy path for node embeddings")
    p.add_argument("--device", default="auto",
                   help="cuda | cpu | auto (default: auto-detect CUDA)")
    args = p.parse_args()

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    # Query the selected device rather than index 0 so that e.g. cuda:1 is
    # reported correctly.
    print(f"Using device: {device}"
          + (f" ({torch.cuda.get_device_name(device)})"
             if device.type == "cuda" else ""))

    # NOTE(security): weights_only=False unpickles arbitrary objects; only
    # load graph and model files from trusted sources.
    # The graph is loaded onto CPU so that files saved with CUDA tensors can
    # still be read on CPU-only hosts; tensors are moved to `device` below.
    data = torch.load(args.graph, map_location="cpu", weights_only=False)
    state_dict = torch.load(args.model, map_location=device,
                            weights_only=False)

    # Build edge_index and edge_weight (undirected, consistent with training)
    ei, ew = _undirected_edges(data)

    # Read hyperparameters from metrics.json (preferred) or infer from
    # state_dict
    metrics = _load_metrics(args.model)
    encoder_name = metrics.get("encoder", "gcn")
    hidden = metrics.get("hidden")
    latent = metrics.get("latent")
    heads = metrics.get("heads") or 4

    if hidden is None or latent is None:
        print("metrics.json not found or incomplete — inferring dims from state_dict (GCN only)")
        _, hidden, latent = _infer_dims_gcn(state_dict)
        encoder_name = "gcn"

    # The input dimension always comes from the graph being encoded.
    in_dim = data.x.shape[1]
    print(f"Encoder: {encoder_name} in_dim={in_dim} hidden={hidden} latent={latent}")

    enc = build_encoder(encoder_name, in_dim=in_dim, hidden=hidden,
                        latent=latent, dropout=0.0, heads=heads)
    model = VGAE(enc).to(device)
    model.load_state_dict(state_dict)
    model.eval()

    x = data.x.float().to(device)
    ei = ei.to(device)
    if ew is not None:
        ew = ew.to(device)

    # NOTE(review): edge weights are forwarded to every encoder type; confirm
    # that non-GCN encoders from build_encoder accept (or ignore) the third
    # argument.
    with torch.no_grad():
        z = model.encode(x, ei, ew)

    os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
    np.save(args.out, z.cpu().numpy())
    print(f"Saved embeddings → {args.out} shape={z.shape}")


if __name__ == "__main__":
    main()