Files
chromatin-vgae-hic/scripts/encode_graph.py
aman acadbd780c v1.0.0: VGAE applied to GM12878 vs IMR90 chr21 Hi-C at 25kb
Full reproducible pipeline: .mcool + ChIP-seq bigwigs → latent
  embeddings → A/B compartment calls → cross-cell comparison.

  Key results (chr21, 25 kb, latent dim=32):
  - Test AUC=0.777, AP=0.759 (converged epoch 31/300)
  - GM12878 A/B silhouette (cosine) = 0.775
  - IMR90 zero-shot silhouette = 0.443
  - A-compartment bins stable across cell types (mean cosine Δ=0.042)
  - B-compartment bins shift substantially (mean cosine Δ=0.451)
  - 101 B→A and 70 A→B compartment switches GM12878→IMR90
2026-05-15 01:53:04 +02:00

92 lines
2.9 KiB
Python

#!/usr/bin/env python3
"""
Encode a chromatin contact graph using a trained VGAE model.
Dimensions (in_dim, hidden, latent) are inferred automatically from the saved
state_dict. The BatchNorm running statistics from training are restored, so the
same normalisation is applied to held-out cell lines without a separate scaler.
Usage
-----
python scripts/encode_graph.py \\
--model results/GM12878/model.pt \\
--graph data/processed/IMR90_chr21.pt \\
--out results/IMR90/emb.npy
"""
import argparse
import os
import sys
import numpy as np
import torch
from torch_geometric.nn.models import VGAE
sys.path.insert(0, os.path.dirname(__file__))
from model import Encoder
def _infer_dims(state_dict: dict) -> tuple:
"""Infer (in_dim, hidden, latent) from a VGAE state_dict."""
keys = list(state_dict.keys())
def _first_weight(substr):
for k in keys:
if (substr in k
and "weight" in k
and "running" not in k
and "num_batches" not in k):
return state_dict[k].shape
raise KeyError(f"No weight key containing '{substr}' in state_dict. "
f"Available keys: {keys}")
gc1_shape = _first_weight("gc1") # shape [hidden, in_dim]
gc_mu_shape = _first_weight("gc_mu") # shape [latent, hidden]
hidden = gc1_shape[0]
latent = gc_mu_shape[0]
# in_dim from BatchNorm weight (shape [in_dim])
for k in keys:
if "norm" in k and k.endswith("weight") and "running" not in k:
in_dim = state_dict[k].shape[0]
break
else:
in_dim = gc1_shape[1] # fallback: second dim of gc1 weight
return in_dim, hidden, latent
def main():
    """Encode a graph with a trained VGAE and save its node embeddings.

    Command-line entry point. It:

    * reads ``--graph`` (the Data ``.pt`` file from build_graph.py) and
      ``--model`` (the ``model.pt`` state_dict from train_vgae.py);
    * rebuilds the ``Encoder`` with dimensions from ``_infer_dims``;
    * encodes the graph in eval mode without gradients;
    * writes the embedding matrix to ``--out`` with ``np.save``.

    Exits via ``argparse`` error reporting if the graph has no node features,
    or if their width does not match the model's input dimension.
    """
    p = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    p.add_argument("--model", required=True,
                   help="Path to model.pt saved by train_vgae.py")
    p.add_argument("--graph", required=True,
                   help="Path to Data .pt file from build_graph.py")
    p.add_argument("--out", required=True,
                   help="Output .npy path for node embeddings")
    args = p.parse_args()

    # Use map_location="cpu" for BOTH files. The model is rebuilt on CPU, and
    # without map_location torch.load restores tensors to the device they were
    # saved from. A graph saved from CUDA tensors would then land on the GPU
    # (device mismatch in encode) or fail to load on a CPU-only machine.
    # NOTE(review): weights_only=False unpickles arbitrary Python objects, so
    # only pass trusted files. The graph needs it for the Data object. The
    # model file is consumed as a plain state_dict and could probably be
    # loaded with weights_only=True; confirm against train_vgae.py.
    data = torch.load(args.graph, map_location="cpu", weights_only=False)
    state_dict = torch.load(args.model, map_location="cpu", weights_only=False)

    in_dim, hidden, latent = _infer_dims(state_dict)
    print(f"Inferred: in_dim={in_dim} hidden={hidden} latent={latent}")

    # Fail early with a readable message instead of a shape error deep inside
    # the encoder, e.g. when a held-out graph was built with different features.
    x = getattr(data, "x", None)
    if x is None:
        p.error(f"{args.graph} has no node features (data.x is missing)")
    if x.dim() != 2 or x.size(1) != in_dim:
        p.error(
            f"{args.graph}: node features have shape {tuple(x.shape)}, but the "
            f"model expects [num_nodes, {in_dim}]; the graph must be built "
            "with the same node features used in training"
        )

    enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
    model = VGAE(enc)
    model.load_state_dict(state_dict)
    # eval() switches normalisation layers to their restored running statistics.
    # In PyG, VGAE.encode also returns the posterior mean (no sampling noise)
    # when the model is not in training mode.
    model.eval()
    with torch.no_grad():
        z = model.encode(x.float(), data.edge_index)

    os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
    np.save(args.out, z.cpu().numpy())
    # np.save appends ".npy" when the path lacks that suffix; report the file
    # that was actually written.
    saved = args.out if args.out.endswith(".npy") else f"{args.out}.npy"
    print(f"Saved embeddings → {saved} shape={z.shape}")
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()