#!/usr/bin/env python3
"""
Encode a chromatin contact graph using a trained VGAE model.

Dimensions (in_dim, hidden, latent) are inferred automatically from the saved
state_dict. The BatchNorm running statistics from training are restored, so
the same normalisation is applied to held-out cell lines without a separate
scaler.

Usage
-----
python scripts/encode_graph.py \\
    --model results/GM12878/model.pt \\
    --graph data/processed/IMR90_chr21.pt \\
    --out results/IMR90/emb.npy
"""

import argparse
import os
import sys

import numpy as np
import torch
from torch_geometric.nn.models import VGAE

# Make the sibling ``model.py`` importable when this file is run as a script.
sys.path.insert(0, os.path.dirname(__file__))
from model import Encoder


def _infer_dims(state_dict: dict) -> tuple:
    """Infer (in_dim, hidden, latent) from a VGAE state_dict.

    Parameters
    ----------
    state_dict
        Mapping from parameter/buffer name to tensor, as passed to
        ``load_state_dict``. Only the keys and each value's ``.shape``
        are read.

    Returns
    -------
    tuple
        ``(in_dim, hidden, latent)``.

    Raises
    ------
    KeyError
        If no weight entry whose name contains ``"gc1"`` or ``"gc_mu"``
        exists in ``state_dict``.
    """
    keys = list(state_dict.keys())

    def _first_weight(substr):
        """Return the shape of the first weight entry whose key contains *substr*."""
        for k in keys:
            # Skip BatchNorm buffers (running_* / num_batches_tracked).
            if (substr in k and "weight" in k
                    and "running" not in k and "num_batches" not in k):
                return state_dict[k].shape
        raise KeyError(f"No weight key containing '{substr}' in state_dict. "
                       f"Available keys: {keys}")

    # NOTE(review): assumes the layer weights are stored as [out, in], i.e.
    # gc1 -> [hidden, in_dim] and gc_mu -> [latent, hidden]; confirm against
    # the Encoder definition and the installed torch_geometric version.
    gc1_shape = _first_weight("gc1")
    gc_mu_shape = _first_weight("gc_mu")
    hidden = gc1_shape[0]
    latent = gc_mu_shape[0]

    # Prefer the first normalisation-layer weight for in_dim.
    # NOTE(review): this presumes the first "norm" layer acts on the raw
    # input features (weight shape [in_dim]); verify against Encoder.
    for k in keys:
        if "norm" in k and k.endswith("weight") and "running" not in k:
            in_dim = state_dict[k].shape[0]
            break
    else:
        # No normalisation layer found: fall back to the second dimension
        # of the gc1 weight.
        in_dim = gc1_shape[1]

    return in_dim, hidden, latent


def main():
    """Parse CLI arguments, encode the graph and save the node embeddings.

    Loads the graph and the model state_dict onto the CPU, rebuilds the
    encoder with the inferred dimensions, runs ``model.encode`` in eval
    mode without gradients, and writes the result to ``--out`` as ``.npy``.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--model", required=True,
                   help="Path to model.pt saved by train_vgae.py")
    p.add_argument("--graph", required=True,
                   help="Path to Data .pt file from build_graph.py")
    p.add_argument("--out", required=True,
                   help="Output .npy path for node embeddings")
    args = p.parse_args()

    # Load BOTH artefacts onto the CPU. Without map_location, a graph saved
    # from a CUDA session either fails to load on a CPU-only host or ends up
    # on a different device than the (CPU) model below.
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects;
    # only pass files from a trusted source.
    data = torch.load(args.graph, map_location="cpu", weights_only=False)
    state_dict = torch.load(args.model, map_location="cpu",
                            weights_only=False)

    in_dim, hidden, latent = _infer_dims(state_dict)
    print(f"Inferred: in_dim={in_dim} hidden={hidden} latent={latent}")

    enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
    model = VGAE(enc)
    model.load_state_dict(state_dict)
    # eval() switches modules to inference behaviour (e.g. BatchNorm uses
    # the restored running statistics instead of batch statistics).
    model.eval()

    with torch.no_grad():
        z = model.encode(data.x.float(), data.edge_index)

    # abspath guards against an empty dirname when --out is a bare filename.
    os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
    np.save(args.out, z.cpu().numpy())
    print(f"Saved embeddings → {args.out} shape={z.shape}")


if __name__ == "__main__":
    main()