v1.0.0: VGAE applied to GM12878 vs IMR90 chr21 Hi-C at 25kb
Full reproducible pipeline: .mcool + ChIP-seq bigwigs → latent embeddings → A/B compartment calls → cross-cell comparison. Key results (chr21, 25 kb, latent dim=32): - Test AUC=0.777, AP=0.759 (converged epoch 31/300) - GM12878 A/B silhouette (cosine) = 0.775 - IMR90 zero-shot silhouette = 0.443 - A-compartment bins stable across cell types (mean cosine Δ=0.042) - B-compartment bins shift substantially (mean cosine Δ=0.451) - 101 B→A and 70 A→B compartment switches GM12878→IMR90
This commit is contained in:
@@ -1,63 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Encode a new graph using a trained VGAE model.
|
||||
Automatically infers hidden/latent dimensions from saved weights.
|
||||
Encode a chromatin contact graph using a trained VGAE model.
|
||||
|
||||
Dimensions (in_dim, hidden, latent) are inferred automatically from the saved
|
||||
state_dict. The BatchNorm running statistics from training are restored, so the
|
||||
same normalisation is applied to held-out cell lines without a separate scaler.
|
||||
|
||||
Usage
|
||||
-----
|
||||
python scripts/encode_graph.py \\
|
||||
--model results/GM12878/model.pt \\
|
||||
--graph data/processed/IMR90_chr21.pt \\
|
||||
--out results/IMR90/emb.npy
|
||||
"""
|
||||
|
||||
import argparse, torch, numpy as np
|
||||
from torch_geometric.nn import GCNConv, VGAE
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Reuse your Encoder definition directly here for clarity
|
||||
class Encoder(torch.nn.Module):
|
||||
def __init__(self, in_dim, hidden, latent, dropout=0.2):
|
||||
super().__init__()
|
||||
self.gc1 = GCNConv(in_dim, hidden)
|
||||
self.gc_mu = GCNConv(hidden, latent)
|
||||
self.gc_log = GCNConv(hidden, latent)
|
||||
self.dropout = dropout
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_geometric.nn.models import VGAE
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
import torch.nn.functional as F
|
||||
h = self.gc1(x, edge_index)
|
||||
h = F.relu(h)
|
||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from model import Encoder
|
||||
|
||||
|
||||
def _infer_dims(state_dict: dict) -> tuple:
|
||||
"""Infer (in_dim, hidden, latent) from a VGAE state_dict."""
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
def _first_weight(substr):
|
||||
for k in keys:
|
||||
if (substr in k
|
||||
and "weight" in k
|
||||
and "running" not in k
|
||||
and "num_batches" not in k):
|
||||
return state_dict[k].shape
|
||||
raise KeyError(f"No weight key containing '{substr}' in state_dict. "
|
||||
f"Available keys: {keys}")
|
||||
|
||||
gc1_shape = _first_weight("gc1") # shape [hidden, in_dim]
|
||||
gc_mu_shape = _first_weight("gc_mu") # shape [latent, hidden]
|
||||
hidden = gc1_shape[0]
|
||||
latent = gc_mu_shape[0]
|
||||
|
||||
# in_dim from BatchNorm weight (shape [in_dim])
|
||||
for k in keys:
|
||||
if "norm" in k and k.endswith("weight") and "running" not in k:
|
||||
in_dim = state_dict[k].shape[0]
|
||||
break
|
||||
else:
|
||||
in_dim = gc1_shape[1] # fallback: second dim of gc1 weight
|
||||
|
||||
return in_dim, hidden, latent
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--model", required=True)
|
||||
p.add_argument("--graph", required=True)
|
||||
p.add_argument("--out", required=True)
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
p.add_argument("--model", required=True,
|
||||
help="Path to model.pt saved by train_vgae.py")
|
||||
p.add_argument("--graph", required=True,
|
||||
help="Path to Data .pt file from build_graph.py")
|
||||
p.add_argument("--out", required=True,
|
||||
help="Output .npy path for node embeddings")
|
||||
args = p.parse_args()
|
||||
|
||||
# ---- Load data and model state ----
|
||||
data = torch.load(args.graph)
|
||||
model_state = torch.load(args.model, map_location="cpu")
|
||||
data = torch.load(args.graph, weights_only=False)
|
||||
state_dict = torch.load(args.model, map_location="cpu", weights_only=False)
|
||||
|
||||
# ---- Infer dimensions dynamically ----
|
||||
in_dim = data.x.size(1)
|
||||
# detect hidden and latent dimensions safely
|
||||
keys = list(model_state.keys())
|
||||
gc1_weight = [k for k in keys if "gc1" in k and "weight" in k][0]
|
||||
gc_mu_weight = [k for k in keys if "gc_mu" in k and "weight" in k][0]
|
||||
in_dim, hidden, latent = _infer_dims(state_dict)
|
||||
print(f"Inferred: in_dim={in_dim} hidden={hidden} latent={latent}")
|
||||
|
||||
hidden = model_state[gc1_weight].shape[0]
|
||||
latent = model_state[gc_mu_weight].shape[0]
|
||||
|
||||
print(f"Inferred dims: in={in_dim}, hidden={hidden}, latent={latent}")
|
||||
|
||||
enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
|
||||
enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
|
||||
model = VGAE(enc)
|
||||
model.load_state_dict(model_state)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
# ---- Encode ----
|
||||
with torch.no_grad():
|
||||
z = model.encode(data.x.float(), data.edge_index)
|
||||
|
||||
os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
|
||||
np.save(args.out, z.cpu().numpy())
|
||||
print(f"Saved embeddings → {args.out} shape={z.shape}")
|
||||
print(f"Saved embeddings → {args.out} shape={z.shape}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user