fix: three silent bugs in training and inference
- Edge weights (ICE log1p) were computed but never passed to GCNConv
- encode_graph.py hardcoded GCNEncoder regardless of saved model type
- Inference graph lacked the to_undirected + edge_weight pipeline used in training
This commit is contained in:
@@ -2,9 +2,10 @@
|
||||
"""
|
||||
Encode a chromatin contact graph using a trained VGAE model.
|
||||
|
||||
Dimensions (in_dim, hidden, latent) are inferred automatically from the saved
|
||||
state_dict. The BatchNorm running statistics from training are restored, so the
|
||||
same normalisation is applied to held-out cell lines without a separate scaler.
|
||||
Dimensions (in_dim, hidden, latent) and encoder type are read from the
|
||||
metrics.json saved alongside model.pt. Edge weights are passed to GCN-based
|
||||
encoders so the same weighted message-passing used during training is applied
|
||||
at inference time.
|
||||
|
||||
Usage
|
||||
-----
|
||||
@@ -15,19 +16,30 @@ Usage
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_geometric.nn.models import VGAE
|
||||
from torch_geometric.utils import remove_self_loops, to_undirected
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from model import Encoder
|
||||
from model import build_encoder, GCNEncoder
|
||||
|
||||
|
||||
def _infer_dims(state_dict: dict) -> tuple:
|
||||
"""Infer (in_dim, hidden, latent) from a VGAE state_dict."""
|
||||
def _load_metrics(model_path: str) -> dict:
|
||||
"""Read metrics.json from the same directory as model.pt."""
|
||||
metrics_path = os.path.join(os.path.dirname(os.path.abspath(model_path)), "metrics.json")
|
||||
if os.path.exists(metrics_path):
|
||||
with open(metrics_path) as f:
|
||||
return json.load(f)
|
||||
return {}
|
||||
|
||||
|
||||
def _infer_dims_gcn(state_dict: dict) -> tuple:
|
||||
"""Fallback: infer (in_dim, hidden, latent) from a GCN state_dict."""
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
def _first_weight(substr):
|
||||
@@ -37,22 +49,18 @@ def _infer_dims(state_dict: dict) -> tuple:
|
||||
and "running" not in k
|
||||
and "num_batches" not in k):
|
||||
return state_dict[k].shape
|
||||
raise KeyError(f"No weight key containing '{substr}' in state_dict. "
|
||||
f"Available keys: {keys}")
|
||||
raise KeyError(f"No weight key containing '{substr}'. Keys: {keys}")
|
||||
|
||||
gc1_shape = _first_weight("gc1") # shape [hidden, in_dim]
|
||||
gc_mu_shape = _first_weight("gc_mu") # shape [latent, hidden]
|
||||
gc1_shape = _first_weight("gc1")
|
||||
gc_mu_shape = _first_weight("gc_mu")
|
||||
hidden = gc1_shape[0]
|
||||
latent = gc_mu_shape[0]
|
||||
|
||||
# in_dim from BatchNorm weight (shape [in_dim])
|
||||
for k in keys:
|
||||
if "norm" in k and k.endswith("weight") and "running" not in k:
|
||||
in_dim = state_dict[k].shape[0]
|
||||
break
|
||||
else:
|
||||
in_dim = gc1_shape[1] # fallback: second dim of gc1 weight
|
||||
|
||||
in_dim = gc1_shape[1]
|
||||
return in_dim, hidden, latent
|
||||
|
||||
|
||||
@@ -71,16 +79,37 @@ def main():
|
||||
data = torch.load(args.graph, weights_only=False)
|
||||
state_dict = torch.load(args.model, map_location="cpu", weights_only=False)
|
||||
|
||||
in_dim, hidden, latent = _infer_dims(state_dict)
|
||||
print(f"Inferred: in_dim={in_dim} hidden={hidden} latent={latent}")
|
||||
# Build edge_index and edge_weight (undirected, consistent with training)
|
||||
ei, _ = remove_self_loops(data.edge_index)
|
||||
if hasattr(data, "edge_weight") and data.edge_weight is not None:
|
||||
ei, ew = to_undirected(ei, data.edge_weight, num_nodes=data.num_nodes)
|
||||
else:
|
||||
ei = to_undirected(ei, num_nodes=data.num_nodes)
|
||||
ew = None
|
||||
|
||||
enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
|
||||
# Read hyperparameters from metrics.json (preferred) or infer from state_dict
|
||||
metrics = _load_metrics(args.model)
|
||||
encoder_name = metrics.get("encoder", "gcn")
|
||||
hidden = metrics.get("hidden")
|
||||
latent = metrics.get("latent")
|
||||
heads = metrics.get("heads") or 4
|
||||
|
||||
if hidden is None or latent is None:
|
||||
print("metrics.json not found or incomplete — inferring dims from state_dict (GCN only)")
|
||||
_, hidden, latent = _infer_dims_gcn(state_dict)
|
||||
encoder_name = "gcn"
|
||||
|
||||
in_dim = data.x.shape[1]
|
||||
print(f"Encoder: {encoder_name} in_dim={in_dim} hidden={hidden} latent={latent}")
|
||||
|
||||
enc = build_encoder(encoder_name, in_dim=in_dim, hidden=hidden,
|
||||
latent=latent, dropout=0.0, heads=heads)
|
||||
model = VGAE(enc)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
z = model.encode(data.x.float(), data.edge_index)
|
||||
z = model.encode(data.x.float(), ei, ew)
|
||||
|
||||
os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
|
||||
np.save(args.out, z.cpu().numpy())
|
||||
|
||||
Reference in New Issue
Block a user