v1.0.0: VGAE applied to GM12878 vs IMR90 chr21 Hi-C at 25kb
Full reproducible pipeline: .mcool + ChIP-seq bigwigs → latent embeddings → A/B compartment calls → cross-cell comparison. Key results (chr21, 25 kb, latent dim=32): - Test AUC=0.777, AP=0.759 (converged epoch 31/300) - GM12878 A/B silhouette (cosine) = 0.775 - IMR90 zero-shot silhouette = 0.443 - A-compartment bins stable across cell types (mean cosine Δ=0.042) - B-compartment bins shift substantially (mean cosine Δ=0.451) - 101 B→A and 70 A→B compartment switches GM12878→IMR90
This commit is contained in:
@@ -50,7 +50,8 @@ def main():
|
||||
if emb1.shape != emb2.shape:
|
||||
raise ValueError(f"Shape mismatch: {emb1.shape} vs {emb2.shape}")
|
||||
|
||||
os.makedirs(os.path.dirname(args.prefix), exist_ok=True)
|
||||
prefix_dir = os.path.dirname(os.path.abspath(args.prefix))
|
||||
os.makedirs(prefix_dir, exist_ok=True)
|
||||
n_bins, n_dim = emb1.shape
|
||||
print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims")
|
||||
|
||||
|
||||
170
scripts/compute_compartments.py
Normal file
170
scripts/compute_compartments.py
Normal file
@@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Compute A/B chromatin compartments from a Hi-C .mcool file.
|
||||
|
||||
Algorithm
|
||||
---------
|
||||
1. Load ICE-balanced contact matrix for the target chromosome.
|
||||
2. Distance-normalise to O/E (divide each diagonal by its mean contact frequency).
|
||||
3. Compute Pearson correlation matrix of the O/E rows.
|
||||
4. PCA of the correlation matrix; PC1 distinguishes A from B compartments.
|
||||
5. Orient the PC1 sign using --bigwig_orient (e.g. CTCF):
|
||||
positive PC1 → high signal in that track.
|
||||
With CTCF: positive PC1 = CTCF-enriched = A compartment (active).
|
||||
With H3K27me3: pass --flip_orient so positive PC1 = B compartment (repressive).
|
||||
|
||||
Output
|
||||
------
|
||||
CSV with columns: chrom, start, end, pc1, compartment (A / B / N for masked bins).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import cooler
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
|
||||
def _bin_bigwig(bw_path: str, chrom: str, bins) -> np.ndarray:
|
||||
"""Average bigWig signal over a list of (start, end) genomic bins."""
|
||||
import pyBigWig
|
||||
bw = pyBigWig.open(bw_path)
|
||||
chrom_len = bw.chroms().get(chrom, 0)
|
||||
vals = []
|
||||
for s, e in bins:
|
||||
s, e = max(0, int(s)), min(chrom_len, int(e))
|
||||
if s >= e:
|
||||
vals.append(0.0)
|
||||
continue
|
||||
v = bw.stats(chrom, s, e, type="mean")[0]
|
||||
vals.append(0.0 if v is None or np.isnan(v) else float(v))
|
||||
bw.close()
|
||||
return np.array(vals)
|
||||
|
||||
|
||||
def _observed_over_expected(matrix: np.ndarray) -> np.ndarray:
|
||||
"""Distance-normalise a symmetric contact matrix (O/E transform)."""
|
||||
n = matrix.shape[0]
|
||||
oe = np.zeros((n, n), dtype=float)
|
||||
for d in range(n):
|
||||
idx = np.arange(n - d)
|
||||
diag = matrix[idx, idx + d].astype(float)
|
||||
positive = diag[diag > 0]
|
||||
if positive.size == 0:
|
||||
continue
|
||||
mean_d = positive.mean()
|
||||
norm_diag = np.where((np.isnan(diag)) | (diag == 0), 0.0, diag / mean_d)
|
||||
oe[idx, idx + d] = norm_diag
|
||||
if d > 0:
|
||||
oe[idx + d, idx] = norm_diag
|
||||
return oe
|
||||
|
||||
|
||||
def compute_compartments(
|
||||
mcool_path: str,
|
||||
chrom: str,
|
||||
res: int,
|
||||
orient_signal=None,
|
||||
flip_orient: bool = False,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Return a DataFrame (chrom, start, end, pc1, compartment).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
orient_signal : array-like, optional
|
||||
Per-bin 1-D signal used to fix the sign of PC1.
|
||||
Pass CTCF signal for positive-PC1 = A convention.
|
||||
flip_orient : bool
|
||||
If True, high orient_signal maps to negative PC1 (use with H3K27me3).
|
||||
"""
|
||||
c = cooler.Cooler(f"{mcool_path}::resolutions/{res}")
|
||||
bins_df = c.bins().fetch(chrom).reset_index(drop=True)
|
||||
matrix = c.matrix(balance=True).fetch(chrom).astype(float)
|
||||
|
||||
bad_bins = np.isnan(matrix).all(axis=0) | (matrix.sum(axis=0) == 0)
|
||||
np.nan_to_num(matrix, nan=0.0, copy=False)
|
||||
|
||||
oe = _observed_over_expected(matrix)
|
||||
|
||||
good = ~bad_bins
|
||||
oe_good = oe[np.ix_(good, good)]
|
||||
|
||||
# Zero rows produce NaN in corrcoef; add tiny noise to avoid singularity
|
||||
row_norms = np.linalg.norm(oe_good, axis=1)
|
||||
oe_good[row_norms == 0] += 1e-9
|
||||
|
||||
corr = np.corrcoef(oe_good)
|
||||
np.nan_to_num(corr, nan=0.0, copy=False)
|
||||
|
||||
pca = PCA(n_components=3, random_state=42)
|
||||
pcs = pca.fit_transform(corr)
|
||||
pc1_good = pcs[:, 0]
|
||||
|
||||
pc1 = np.full(len(bins_df), np.nan)
|
||||
pc1[good] = pc1_good
|
||||
|
||||
if orient_signal is not None:
|
||||
sig = np.asarray(orient_signal, dtype=float)
|
||||
sig_good = sig[good]
|
||||
valid = ~np.isnan(sig_good) & ~np.isnan(pc1_good)
|
||||
if valid.sum() > 10:
|
||||
r = np.corrcoef(pc1_good[valid], sig_good[valid])[0, 1]
|
||||
# By default: positive orient_signal → positive PC1.
|
||||
# flip_orient reverses this (e.g. H3K27me3 → positive PC1 = B).
|
||||
if (r < 0 and not flip_orient) or (r > 0 and flip_orient):
|
||||
pc1 = -pc1
|
||||
|
||||
bins_df["pc1"] = pc1
|
||||
bins_df["compartment"] = np.where(
|
||||
np.isnan(bins_df["pc1"]), "N",
|
||||
np.where(bins_df["pc1"] > 0, "A", "B"),
|
||||
)
|
||||
return bins_df[["chrom", "start", "end", "pc1", "compartment"]]
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
p.add_argument("--mcool", required=True, help="Path to .mcool file")
|
||||
p.add_argument("--chrom", required=True, help="Chromosome (e.g. chr21)")
|
||||
p.add_argument("--res", type=int, default=25000, help="Resolution in bp (default: 25000)")
|
||||
p.add_argument("--bigwig_orient",
|
||||
help="bigWig track for PC1 sign orientation (recommended: CTCF)")
|
||||
p.add_argument("--flip_orient", action="store_true",
|
||||
help="Flip orientation: high signal → negative PC1 (use with H3K27me3)")
|
||||
p.add_argument("--out", required=True, help="Output CSV path")
|
||||
args = p.parse_args()
|
||||
|
||||
orient_signal = None
|
||||
if args.bigwig_orient:
|
||||
c = cooler.Cooler(f"{args.mcool}::resolutions/{args.res}")
|
||||
bins_df = c.bins().fetch(args.chrom).reset_index(drop=True)
|
||||
coords = list(zip(bins_df["start"].values, bins_df["end"].values))
|
||||
orient_signal = _bin_bigwig(args.bigwig_orient, args.chrom, coords)
|
||||
print(f"Loaded orientation signal: {os.path.basename(args.bigwig_orient)}")
|
||||
|
||||
df = compute_compartments(
|
||||
args.mcool, args.chrom, args.res,
|
||||
orient_signal=orient_signal,
|
||||
flip_orient=args.flip_orient,
|
||||
)
|
||||
|
||||
os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
|
||||
df.to_csv(args.out, index=False)
|
||||
|
||||
n_a = (df["compartment"] == "A").sum()
|
||||
n_b = (df["compartment"] == "B").sum()
|
||||
n_nan = (df["compartment"] == "N").sum()
|
||||
print(f"Saved → {args.out}")
|
||||
print(f" A: {n_a} B: {n_b} N/masked: {n_nan} bins")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,63 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Encode a new graph using a trained VGAE model.
|
||||
Automatically infers hidden/latent dimensions from saved weights.
|
||||
Encode a chromatin contact graph using a trained VGAE model.
|
||||
|
||||
Dimensions (in_dim, hidden, latent) are inferred automatically from the saved
|
||||
state_dict. The BatchNorm running statistics from training are restored, so the
|
||||
same normalisation is applied to held-out cell lines without a separate scaler.
|
||||
|
||||
Usage
|
||||
-----
|
||||
python scripts/encode_graph.py \\
|
||||
--model results/GM12878/model.pt \\
|
||||
--graph data/processed/IMR90_chr21.pt \\
|
||||
--out results/IMR90/emb.npy
|
||||
"""
|
||||
|
||||
import argparse, torch, numpy as np
|
||||
from torch_geometric.nn import GCNConv, VGAE
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Reuse your Encoder definition directly here for clarity
|
||||
class Encoder(torch.nn.Module):
|
||||
def __init__(self, in_dim, hidden, latent, dropout=0.2):
|
||||
super().__init__()
|
||||
self.gc1 = GCNConv(in_dim, hidden)
|
||||
self.gc_mu = GCNConv(hidden, latent)
|
||||
self.gc_log = GCNConv(hidden, latent)
|
||||
self.dropout = dropout
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_geometric.nn.models import VGAE
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
import torch.nn.functional as F
|
||||
h = self.gc1(x, edge_index)
|
||||
h = F.relu(h)
|
||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from model import Encoder
|
||||
|
||||
|
||||
def _infer_dims(state_dict: dict) -> tuple:
|
||||
"""Infer (in_dim, hidden, latent) from a VGAE state_dict."""
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
def _first_weight(substr):
|
||||
for k in keys:
|
||||
if (substr in k
|
||||
and "weight" in k
|
||||
and "running" not in k
|
||||
and "num_batches" not in k):
|
||||
return state_dict[k].shape
|
||||
raise KeyError(f"No weight key containing '{substr}' in state_dict. "
|
||||
f"Available keys: {keys}")
|
||||
|
||||
gc1_shape = _first_weight("gc1") # shape [hidden, in_dim]
|
||||
gc_mu_shape = _first_weight("gc_mu") # shape [latent, hidden]
|
||||
hidden = gc1_shape[0]
|
||||
latent = gc_mu_shape[0]
|
||||
|
||||
# in_dim from BatchNorm weight (shape [in_dim])
|
||||
for k in keys:
|
||||
if "norm" in k and k.endswith("weight") and "running" not in k:
|
||||
in_dim = state_dict[k].shape[0]
|
||||
break
|
||||
else:
|
||||
in_dim = gc1_shape[1] # fallback: second dim of gc1 weight
|
||||
|
||||
return in_dim, hidden, latent
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--model", required=True)
|
||||
p.add_argument("--graph", required=True)
|
||||
p.add_argument("--out", required=True)
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
p.add_argument("--model", required=True,
|
||||
help="Path to model.pt saved by train_vgae.py")
|
||||
p.add_argument("--graph", required=True,
|
||||
help="Path to Data .pt file from build_graph.py")
|
||||
p.add_argument("--out", required=True,
|
||||
help="Output .npy path for node embeddings")
|
||||
args = p.parse_args()
|
||||
|
||||
# ---- Load data and model state ----
|
||||
data = torch.load(args.graph)
|
||||
model_state = torch.load(args.model, map_location="cpu")
|
||||
data = torch.load(args.graph, weights_only=False)
|
||||
state_dict = torch.load(args.model, map_location="cpu", weights_only=False)
|
||||
|
||||
# ---- Infer dimensions dynamically ----
|
||||
in_dim = data.x.size(1)
|
||||
# detect hidden and latent dimensions safely
|
||||
keys = list(model_state.keys())
|
||||
gc1_weight = [k for k in keys if "gc1" in k and "weight" in k][0]
|
||||
gc_mu_weight = [k for k in keys if "gc_mu" in k and "weight" in k][0]
|
||||
in_dim, hidden, latent = _infer_dims(state_dict)
|
||||
print(f"Inferred: in_dim={in_dim} hidden={hidden} latent={latent}")
|
||||
|
||||
hidden = model_state[gc1_weight].shape[0]
|
||||
latent = model_state[gc_mu_weight].shape[0]
|
||||
|
||||
print(f"Inferred dims: in={in_dim}, hidden={hidden}, latent={latent}")
|
||||
|
||||
enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
|
||||
enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
|
||||
model = VGAE(enc)
|
||||
model.load_state_dict(model_state)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
# ---- Encode ----
|
||||
with torch.no_grad():
|
||||
z = model.encode(data.x.float(), data.edge_index)
|
||||
|
||||
os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
|
||||
np.save(args.out, z.cpu().numpy())
|
||||
print(f"Saved embeddings → {args.out} shape={z.shape}")
|
||||
print(f"Saved embeddings → {args.out} shape={z.shape}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
32
scripts/model.py
Normal file
32
scripts/model.py
Normal file
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Shared VGAE encoder. Imported by train_vgae.py and encode_graph.py."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Two-layer GCN encoder for VGAE with input BatchNorm.
|
||||
|
||||
Architecture: BatchNorm → GCN(hidden) → ReLU → Dropout → GCN_mu / GCN_logstd
|
||||
|
||||
The BatchNorm layer normalises raw ChIP-seq signals and its running statistics
|
||||
are saved in model.pt, so encode_graph.py applies identical normalisation to
|
||||
held-out cell lines without a separate scaler file.
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2):
|
||||
super().__init__()
|
||||
self.norm = nn.BatchNorm1d(in_dim)
|
||||
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
|
||||
self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
x = self.norm(x)
|
||||
h = F.relu(self.gc1(x, edge_index))
|
||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
|
||||
@@ -1,179 +1,185 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
|
||||
---
|
||||
Inputs:
|
||||
- A PyTorch Geometric Data object saved with torch.save()
|
||||
- from build_graph.py
|
||||
---
|
||||
Outputs (under results/):
|
||||
- model.pt : trained VGAE state_dict
|
||||
- emb.npy : node embeddings (mean; shape [num_nodes, latent_dim])
|
||||
- metrics.json : train/val/test AUC/AP summary
|
||||
|
||||
Inputs
|
||||
------
|
||||
PyTorch Geometric Data object saved by build_graph.py.
|
||||
|
||||
Outputs (under --outdir)
|
||||
------------------------
|
||||
model.pt trained VGAE state_dict (includes BatchNorm running statistics)
|
||||
emb.npy node embeddings — mu vector, shape [num_nodes, latent_dim]
|
||||
metrics.json val/test AUC & AP plus all hyperparameters
|
||||
"""
|
||||
|
||||
import os, json, argparse, numpy as np, torch
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_geometric.nn.models import VGAE
|
||||
from torch_geometric.transforms import RandomLinkSplit
|
||||
from torch_geometric.utils import to_undirected, remove_self_loops
|
||||
from torch_geometric.utils import negative_sampling
|
||||
from sklearn.metrics import roc_auc_score, average_precision_score
|
||||
from torch_geometric.utils import (
|
||||
negative_sampling,
|
||||
remove_self_loops,
|
||||
to_undirected,
|
||||
)
|
||||
from sklearn.metrics import average_precision_score, roc_auc_score
|
||||
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2):
|
||||
super().__init__()
|
||||
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
|
||||
self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, edge_index):
|
||||
h = self.gc1(x, edge_index)
|
||||
h = F.relu(h)
|
||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from model import Encoder
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def eval_linkpred(model, data_like, z):
|
||||
"""Compute AUROC/AP using provided positive/negative edges."""
|
||||
pos = data_like.pos_edge_index
|
||||
neg = data_like.neg_edge_index
|
||||
# model.test returns (auc, ap) but relies on torchmetrics in some versions;
|
||||
# compute explicitly for stability:
|
||||
def sigmoid(x): return 1 / (1 + torch.exp(-x))
|
||||
def _eval_linkpred(z, pos_edges, neg_edges):
|
||||
"""Return (AUROC, AP) for link prediction."""
|
||||
def _sigmoid(x):
|
||||
return 1.0 / (1.0 + torch.exp(-x))
|
||||
|
||||
# Inner product decoder scores
|
||||
def scores(edges):
|
||||
def _score(edges):
|
||||
src, dst = edges
|
||||
s = (z[src] * z[dst]).sum(dim=1)
|
||||
return sigmoid(s).cpu().numpy()
|
||||
return _sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy()
|
||||
|
||||
y_true = np.concatenate([np.ones(pos.size(1)), np.zeros(neg.size(1))])
|
||||
y_pred = np.concatenate([scores(pos), scores(neg)])
|
||||
|
||||
auc = roc_auc_score(y_true, y_pred)
|
||||
ap = average_precision_score(y_true, y_pred)
|
||||
return auc, ap
|
||||
y_true = np.concatenate([np.ones(pos_edges.size(1)),
|
||||
np.zeros(neg_edges.size(1))])
|
||||
y_pred = np.concatenate([_score(pos_edges), _score(neg_edges)])
|
||||
return roc_auc_score(y_true, y_pred), average_precision_score(y_true, y_pred)
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--graph", required=True, help="Path to Data .pt file")
|
||||
ap.add_argument("--epochs", type=int, default=100)
|
||||
ap.add_argument("--lr", type=float, default=1e-3)
|
||||
ap.add_argument("--hidden", type=int, default=128)
|
||||
ap.add_argument("--latent", type=int, default=64)
|
||||
ap.add_argument("--dropout", type=float, default=0.2)
|
||||
ap.add_argument("--seed", type=int, default=42)
|
||||
ap.add_argument("--outdir", default="results")
|
||||
ap = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
ap.add_argument("--graph", required=True,
|
||||
help="Path to Data .pt file from build_graph.py")
|
||||
ap.add_argument("--epochs", type=int, default=300)
|
||||
ap.add_argument("--patience", type=int, default=20,
|
||||
help="Early-stopping patience (val-AUC epochs without improvement)")
|
||||
ap.add_argument("--lr", type=float, default=1e-3)
|
||||
ap.add_argument("--hidden", type=int, default=64)
|
||||
ap.add_argument("--latent", type=int, default=32)
|
||||
ap.add_argument("--dropout", type=float, default=0.2)
|
||||
ap.add_argument("--seed", type=int, default=42)
|
||||
ap.add_argument("--outdir", default="results")
|
||||
args = ap.parse_args()
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
os.makedirs(args.outdir, exist_ok=True)
|
||||
|
||||
# Load graph
|
||||
data = torch.load(args.graph)
|
||||
# Coalesce/clean edges
|
||||
# ---- Load and clean graph ----
|
||||
data = torch.load(args.graph, weights_only=False)
|
||||
ei, _ = remove_self_loops(data.edge_index)
|
||||
data.edge_index = to_undirected(ei, num_nodes=data.num_nodes)
|
||||
x = data.x.float()
|
||||
print(f"Graph: {data.num_nodes} nodes "
|
||||
f"{data.edge_index.shape[1]} edges "
|
||||
f"{x.shape[1]} node features")
|
||||
|
||||
# Split edges for link prediction
|
||||
# ---- Edge splits for link-prediction evaluation ----
|
||||
splitter = RandomLinkSplit(
|
||||
num_val=0.1,
|
||||
num_test=0.1,
|
||||
num_val=0.1, num_test=0.1,
|
||||
is_undirected=True,
|
||||
add_negative_train_samples=False,
|
||||
split_labels=False,
|
||||
)
|
||||
train_data, val_data, test_data = splitter(data)
|
||||
|
||||
# Positive edges are just the edges in each split
|
||||
train_data.pos_edge_index = train_data.edge_index
|
||||
val_data.pos_edge_index = val_data.edge_index
|
||||
test_data.pos_edge_index = test_data.edge_index
|
||||
|
||||
# Generate negative edges for validation and test manually
|
||||
for subset in [val_data, test_data]:
|
||||
subset.neg_edge_index = negative_sampling(
|
||||
edge_index=subset.edge_index,
|
||||
for split in (val_data, test_data):
|
||||
split.pos_edge_index = split.edge_index
|
||||
split.neg_edge_index = negative_sampling(
|
||||
edge_index=split.edge_index,
|
||||
num_nodes=data.num_nodes,
|
||||
num_neg_samples=subset.edge_index.size(1),
|
||||
method='sparse'
|
||||
num_neg_samples=split.edge_index.size(1),
|
||||
method="sparse",
|
||||
)
|
||||
|
||||
|
||||
# Model
|
||||
enc = Encoder(in_dim=x.size(1), hidden=args.hidden, latent=args.latent, dropout=args.dropout)
|
||||
# ---- Model ----
|
||||
enc = Encoder(in_dim=x.size(1), hidden=args.hidden,
|
||||
latent=args.latent, dropout=args.dropout)
|
||||
model = VGAE(enc)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
# Training loop
|
||||
# ---- Training loop with early stopping ----
|
||||
best_val_auc = -1.0
|
||||
best_state = None
|
||||
best_state = None
|
||||
no_improve = 0
|
||||
epochs_ran = 0
|
||||
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
# Encode using remaining training edges
|
||||
z = model.encode(x, train_data.edge_index)
|
||||
# Reconstruction loss on positive training edges (negatives sampled inside)
|
||||
loss_recon = model.recon_loss(z, train_data.pos_edge_index)
|
||||
# KL divergence regularizer
|
||||
loss_kl = (1.0 / data.num_nodes) * model.kl_loss()
|
||||
loss = loss_recon + loss_kl
|
||||
z = model.encode(x, train_data.edge_index)
|
||||
loss = (model.recon_loss(z, train_data.pos_edge_index)
|
||||
+ (1.0 / data.num_nodes) * model.kl_loss())
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
z_full = model.encode(x, data.edge_index) # use full graph for eval embeddings
|
||||
val_auc, val_ap = eval_linkpred(model, val_data, z_full)
|
||||
z_full = model.encode(x, data.edge_index)
|
||||
val_auc, val_ap = _eval_linkpred(
|
||||
z_full, val_data.pos_edge_index, val_data.neg_edge_index
|
||||
)
|
||||
|
||||
if val_auc > best_val_auc:
|
||||
best_val_auc = val_auc
|
||||
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
||||
best_state = {k: v.cpu().clone()
|
||||
for k, v in model.state_dict().items()}
|
||||
no_improve = 0
|
||||
else:
|
||||
no_improve += 1
|
||||
|
||||
epochs_ran = epoch
|
||||
if epoch % 10 == 0 or epoch == 1:
|
||||
print(f"[{epoch:03d}/{args.epochs}] loss={loss.item():.4f} | val AUC={val_auc:.4f} AP={val_ap:.4f}")
|
||||
print(f"[{epoch:03d}/{args.epochs}] "
|
||||
f"loss={loss.item():.4f} "
|
||||
f"val AUC={val_auc:.4f} AP={val_ap:.4f}")
|
||||
|
||||
# Save best model
|
||||
if no_improve >= args.patience:
|
||||
print(f"Early stopping at epoch {epoch} "
|
||||
f"(no val-AUC improvement for {args.patience} epochs)")
|
||||
break
|
||||
|
||||
# ---- Restore best checkpoint and compute test metrics ----
|
||||
model.load_state_dict(best_state)
|
||||
model_path = os.path.join(args.outdir, "model.pt")
|
||||
torch.save(model.state_dict(), model_path)
|
||||
|
||||
# Final test metrics
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
z_final = model.encode(x, data.edge_index)
|
||||
test_auc, test_ap = eval_linkpred(model, test_data, z_final)
|
||||
test_auc, test_ap = _eval_linkpred(
|
||||
z_final, test_data.pos_edge_index, test_data.neg_edge_index
|
||||
)
|
||||
|
||||
# ---- Save outputs ----
|
||||
model_path = os.path.join(args.outdir, "model.pt")
|
||||
torch.save(best_state, model_path)
|
||||
|
||||
# Save embeddings & metrics
|
||||
emb_path = os.path.join(args.outdir, "emb.npy")
|
||||
np.save(emb_path, z_final.cpu().numpy())
|
||||
|
||||
metrics = {
|
||||
"val_auc": float(best_val_auc),
|
||||
"test_auc": float(test_auc),
|
||||
"test_ap": float(test_ap),
|
||||
"epochs": args.epochs,
|
||||
"hidden": args.hidden,
|
||||
"latent": args.latent,
|
||||
"dropout": args.dropout,
|
||||
"lr": args.lr,
|
||||
"seed": args.seed
|
||||
"val_auc": float(best_val_auc),
|
||||
"test_auc": float(test_auc),
|
||||
"test_ap": float(test_ap),
|
||||
"epochs_ran": epochs_ran,
|
||||
"epochs_max": args.epochs,
|
||||
"patience": args.patience,
|
||||
"hidden": args.hidden,
|
||||
"latent": args.latent,
|
||||
"dropout": args.dropout,
|
||||
"lr": args.lr,
|
||||
"seed": args.seed,
|
||||
}
|
||||
with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
|
||||
json.dump(metrics, f, indent=2)
|
||||
|
||||
print(f"Saved model -> {model_path}")
|
||||
print(f"Saved embeddings -> {emb_path} (shape={z_final.shape})")
|
||||
print(f"Metrics: AUC(test)={test_auc:.4f}, AP(test)={test_ap:.4f}")
|
||||
print(f"\nSaved model → {model_path}")
|
||||
print(f"Saved embeddings → {emb_path} shape={z_final.shape}")
|
||||
print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
189
scripts/visualize_embeddings.py
Normal file
189
scripts/visualize_embeddings.py
Normal file
@@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Visualise VGAE node embeddings using UMAP.
|
||||
|
||||
Produces (under --prefix):
|
||||
{prefix}_{label}_position.png UMAP coloured by genomic position (bin index)
|
||||
{prefix}_{label}_compartment.png UMAP coloured by A/B compartment (needs --compartments)
|
||||
{prefix}_joint.png Joint UMAP of all supplied cell lines
|
||||
{prefix}_stats.csv Per-embedding summary statistics
|
||||
|
||||
Usage
|
||||
-----
|
||||
python scripts/visualize_embeddings.py \\
|
||||
--emb results/GM12878/emb.npy results/IMR90/emb.npy \\
|
||||
--labels GM12878 IMR90 \\
|
||||
--compartments results/GM12878/compartments_chr21.csv \\
|
||||
results/IMR90/compartments_chr21.csv \\
|
||||
--prefix results/figures/umap
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import silhouette_score
|
||||
import umap
|
||||
|
||||
|
||||
COMPARTMENT_COLORS = {"A": "#E41A1C", "B": "#377EB8", "N": "#AAAAAA"}
|
||||
CELL_LINE_PALETTE = ["#E41A1C", "#4DAF4A", "#984EA3", "#FF7F00", "#377EB8"]
|
||||
|
||||
plt.rcParams.update({
|
||||
"font.family": "sans-serif",
|
||||
"axes.spines.top": False,
|
||||
"axes.spines.right": False,
|
||||
})
|
||||
|
||||
|
||||
def _run_umap(emb: np.ndarray, seed: int = 42) -> np.ndarray:
|
||||
reducer = umap.UMAP(n_components=2, random_state=seed,
|
||||
min_dist=0.3, n_neighbors=15)
|
||||
return reducer.fit_transform(emb)
|
||||
|
||||
|
||||
def _plot_position(coords: np.ndarray, label: str, out_path: str):
|
||||
fig, ax = plt.subplots(figsize=(6.5, 5.5))
|
||||
sc = ax.scatter(coords[:, 0], coords[:, 1],
|
||||
c=np.arange(len(coords)), cmap="plasma",
|
||||
s=4, alpha=0.75, linewidths=0, rasterized=True)
|
||||
cbar = plt.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
|
||||
cbar.set_label("Bin index (5′ → 3′)", fontsize=9)
|
||||
ax.set_title(f"{label} — UMAP coloured by genomic position", fontsize=10)
|
||||
ax.set_xlabel("UMAP 1", fontsize=9)
|
||||
ax.set_ylabel("UMAP 2", fontsize=9)
|
||||
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(out_path, dpi=300)
|
||||
plt.close()
|
||||
|
||||
|
||||
def _plot_compartment(coords: np.ndarray, compartments: np.ndarray,
|
||||
label: str, out_path: str):
|
||||
fig, ax = plt.subplots(figsize=(6.5, 5.5))
|
||||
for comp in ("A", "B", "N"):
|
||||
mask = compartments == comp
|
||||
if mask.sum() == 0:
|
||||
continue
|
||||
ax.scatter(coords[mask, 0], coords[mask, 1],
|
||||
c=COMPARTMENT_COLORS[comp], s=4, alpha=0.75,
|
||||
label=f"{comp} ({mask.sum()} bins)", linewidths=0,
|
||||
rasterized=True)
|
||||
ax.legend(markerscale=3, title="Compartment", fontsize=9,
|
||||
title_fontsize=9, frameon=False)
|
||||
ax.set_title(f"{label} — UMAP coloured by A/B compartment", fontsize=10)
|
||||
ax.set_xlabel("UMAP 1", fontsize=9)
|
||||
ax.set_ylabel("UMAP 2", fontsize=9)
|
||||
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(out_path, dpi=300)
|
||||
plt.close()
|
||||
|
||||
|
||||
def _plot_joint(all_coords: np.ndarray, all_labels: list, out_path: str):
|
||||
fig, ax = plt.subplots(figsize=(7, 6))
|
||||
unique = list(dict.fromkeys(all_labels))
|
||||
arr = np.array(all_labels)
|
||||
for i, label in enumerate(unique):
|
||||
mask = arr == label
|
||||
ax.scatter(all_coords[mask, 0], all_coords[mask, 1],
|
||||
c=CELL_LINE_PALETTE[i % len(CELL_LINE_PALETTE)],
|
||||
s=3, alpha=0.6, label=label, linewidths=0, rasterized=True)
|
||||
ax.legend(markerscale=4, title="Cell line", fontsize=9,
|
||||
title_fontsize=9, frameon=False)
|
||||
ax.set_title("Joint UMAP — chromatin topology embeddings", fontsize=11)
|
||||
ax.set_xlabel("UMAP 1", fontsize=9)
|
||||
ax.set_ylabel("UMAP 2", fontsize=9)
|
||||
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(out_path, dpi=300)
|
||||
plt.close()
|
||||
|
||||
|
||||
def _silhouette(emb: np.ndarray, compartments: np.ndarray) -> float:
|
||||
mask = compartments != "N"
|
||||
if mask.sum() < 20 or len(set(compartments[mask])) < 2:
|
||||
return float("nan")
|
||||
try:
|
||||
return float(silhouette_score(emb[mask], compartments[mask], metric="cosine"))
|
||||
except Exception:
|
||||
return float("nan")
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
p.add_argument("--emb", nargs="+", required=True,
|
||||
help="One or more .npy embedding files")
|
||||
p.add_argument("--labels", nargs="+", required=True,
|
||||
help="Label for each embedding (same order)")
|
||||
p.add_argument("--compartments", nargs="+",
|
||||
help="Compartment CSV files, one per embedding (optional)")
|
||||
p.add_argument("--prefix", default="results/figures/umap",
|
||||
help="Output file prefix")
|
||||
p.add_argument("--seed", type=int, default=42)
|
||||
args = p.parse_args()
|
||||
|
||||
if len(args.emb) != len(args.labels):
|
||||
raise ValueError("--emb and --labels must have the same length")
|
||||
|
||||
os.makedirs(os.path.dirname(os.path.abspath(args.prefix + "_x")), exist_ok=True)
|
||||
|
||||
embs = [np.load(f) for f in args.emb]
|
||||
comp_dfs = []
|
||||
if args.compartments:
|
||||
for f in args.compartments:
|
||||
comp_dfs.append(pd.read_csv(f) if (f and os.path.exists(f)) else None)
|
||||
|
||||
stats_rows = []
|
||||
|
||||
for i, (emb, label) in enumerate(zip(embs, args.labels)):
|
||||
print(f"\n[{label}] {emb.shape[0]} nodes × {emb.shape[1]} dims")
|
||||
coords = _run_umap(emb, seed=args.seed)
|
||||
|
||||
tag = label.replace(" ", "_")
|
||||
_plot_position(coords, label, f"{args.prefix}_{tag}_position.png")
|
||||
print(f" → {args.prefix}_{tag}_position.png")
|
||||
|
||||
comp_arr = None
|
||||
sil = float("nan")
|
||||
if comp_dfs and i < len(comp_dfs) and comp_dfs[i] is not None:
|
||||
comp_arr = comp_dfs[i]["compartment"].values[: len(emb)]
|
||||
_plot_compartment(coords, comp_arr, label,
|
||||
f"{args.prefix}_{tag}_compartment.png")
|
||||
print(f" → {args.prefix}_{tag}_compartment.png")
|
||||
sil = _silhouette(emb, comp_arr)
|
||||
print(f" Silhouette (A/B, cosine): {sil:.4f}")
|
||||
|
||||
stats_rows.append({
|
||||
"label": label,
|
||||
"n_bins": emb.shape[0],
|
||||
"latent_dim": emb.shape[1],
|
||||
"mean_embedding_norm": float(np.linalg.norm(emb, axis=1).mean()),
|
||||
"std_embedding_values": float(emb.std()),
|
||||
"silhouette_AB_cosine": sil,
|
||||
})
|
||||
|
||||
# Joint UMAP when multiple embeddings are supplied
|
||||
if len(embs) > 1:
|
||||
print("\nComputing joint UMAP…")
|
||||
all_emb = np.vstack(embs)
|
||||
all_labels = sum([[lab] * len(e) for lab, e in zip(args.labels, embs)], [])
|
||||
all_coords = _run_umap(all_emb, seed=args.seed)
|
||||
_plot_joint(all_coords, all_labels, f"{args.prefix}_joint.png")
|
||||
print(f" → {args.prefix}_joint.png")
|
||||
|
||||
stats_df = pd.DataFrame(stats_rows)
|
||||
stats_path = f"{args.prefix}_stats.csv"
|
||||
stats_df.to_csv(stats_path, index=False)
|
||||
print(f"\nStats → {stats_path}")
|
||||
print(stats_df.to_string(index=False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user