v1.0.0: VGAE applied to GM12878 vs IMR90 chr21 Hi-C at 25kb

Full reproducible pipeline: .mcool + ChIP-seq bigwigs → latent
  embeddings → A/B compartment calls → cross-cell comparison.

  Key results (chr21, 25 kb, latent dim=32):
  - Test AUC=0.777, AP=0.759 (converged epoch 31/300)
  - GM12878 A/B silhouette (cosine) = 0.775
  - IMR90 zero-shot silhouette = 0.443
  - A-compartment bins stable across cell types (mean cosine Δ=0.042)
  - B-compartment bins shift substantially (mean cosine Δ=0.451)
  - 101 B→A and 70 A→B compartment switches GM12878→IMR90
This commit is contained in:
2026-05-15 01:53:04 +02:00
parent 6c91af655d
commit acadbd780c
27 changed files with 6764 additions and 201 deletions

View File

@@ -0,0 +1,189 @@
#!/usr/bin/env python3
"""
Visualise VGAE node embeddings using UMAP.
Produces (under --prefix):
{prefix}_{label}_position.png UMAP coloured by genomic position (bin index)
{prefix}_{label}_compartment.png UMAP coloured by A/B compartment (needs --compartments)
{prefix}_joint.png Joint UMAP of all supplied cell lines
{prefix}_stats.csv Per-embedding summary statistics
Usage
-----
python scripts/visualize_embeddings.py \\
--emb results/GM12878/emb.npy results/IMR90/emb.npy \\
--labels GM12878 IMR90 \\
--compartments results/GM12878/compartments_chr21.csv \\
results/IMR90/compartments_chr21.csv \\
--prefix results/figures/umap
"""
import argparse
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import silhouette_score
import umap
# Hex colour per compartment label; "N" is drawn in neutral grey.
COMPARTMENT_COLORS = dict(A="#E41A1C", B="#377EB8", N="#AAAAAA")

# Colours assigned to cell lines in the joint plot, in order of first
# appearance; the plotting code wraps around when there are more labels
# than colours.
CELL_LINE_PALETTE = [
    "#E41A1C",
    "#4DAF4A",
    "#984EA3",
    "#FF7F00",
    "#377EB8",
]

# Global matplotlib style for every figure in this script: sans-serif text
# and no top/right axis spines.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["axes.spines.top"] = False
plt.rcParams["axes.spines.right"] = False
def _run_umap(emb: np.ndarray, seed: int = 42) -> np.ndarray:
    """Fit a 2-component UMAP on ``emb`` and return the projected coordinates.

    Parameters
    ----------
    emb : np.ndarray
        Data handed unchanged to ``UMAP.fit_transform``.
    seed : int
        Forwarded to UMAP as ``random_state``.

    Returns
    -------
    np.ndarray
        The result of ``fit_transform`` (two UMAP components per input row).
    """
    umap_settings = {
        "n_components": 2,
        "n_neighbors": 15,
        "min_dist": 0.3,
        "random_state": seed,
    }
    model = umap.UMAP(**umap_settings)
    return model.fit_transform(emb)
def _plot_position(coords: np.ndarray, label: str, out_path: str):
    """Write a scatter plot of ``coords`` coloured by row order.

    Parameters
    ----------
    coords : np.ndarray
        Columns 0 and 1 are used as the x and y positions.
    label : str
        Name shown in the figure title.
    out_path : str
        Destination image path; the figure is saved at 300 dpi.
    """
    # The colour of each point is simply its row index in ``coords``.
    row_index = np.arange(len(coords))

    figure, axis = plt.subplots(figsize=(6.5, 5.5))
    points = axis.scatter(
        coords[:, 0],
        coords[:, 1],
        c=row_index,
        cmap="plasma",
        s=4,
        alpha=0.75,
        linewidths=0,
        rasterized=True,
    )

    colorbar = figure.colorbar(points, ax=axis, shrink=0.8, pad=0.02)
    colorbar.set_label("Bin index (5 → 3)", fontsize=9)

    axis.set_title(f"{label} — UMAP coloured by genomic position", fontsize=10)
    axis.set_xlabel("UMAP 1", fontsize=9)
    axis.set_ylabel("UMAP 2", fontsize=9)
    # Ticks and tick labels are hidden on both axes.
    axis.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)

    figure.tight_layout()
    figure.savefig(out_path, dpi=300)
    plt.close(figure)
def _plot_compartment(coords: np.ndarray, compartments: np.ndarray,
                      label: str, out_path: str):
    """Write a scatter plot of ``coords`` coloured by compartment label.

    Parameters
    ----------
    coords : np.ndarray
        Columns 0 and 1 are used as the x and y positions.
    compartments : np.ndarray
        Per-row labels compared against ``"A"``, ``"B"`` and ``"N"``; rows
        with any other value are not drawn.
    label : str
        Name shown in the figure title.
    out_path : str
        Destination image path; the figure is saved at 300 dpi.
    """
    figure, axis = plt.subplots(figsize=(6.5, 5.5))

    # One scatter call per label so each gets its own legend entry.
    for name in ("A", "B", "N"):
        selected = compartments == name
        n_selected = selected.sum()
        if n_selected == 0:
            # Nothing carries this label; keep it out of the legend.
            continue
        axis.scatter(
            coords[selected, 0],
            coords[selected, 1],
            c=COMPARTMENT_COLORS[name],
            s=4,
            alpha=0.75,
            label=f"{name} ({n_selected} bins)",
            linewidths=0,
            rasterized=True,
        )

    axis.legend(markerscale=3, title="Compartment", fontsize=9,
                title_fontsize=9, frameon=False)
    axis.set_title(f"{label} — UMAP coloured by A/B compartment", fontsize=10)
    axis.set_xlabel("UMAP 1", fontsize=9)
    axis.set_ylabel("UMAP 2", fontsize=9)
    # Ticks and tick labels are hidden on both axes.
    axis.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)

    figure.tight_layout()
    figure.savefig(out_path, dpi=300)
    plt.close(figure)
def _plot_joint(all_coords: np.ndarray, all_labels: list, out_path: str):
    """Write one scatter plot with every point coloured by its label.

    Parameters
    ----------
    all_coords : np.ndarray
        Columns 0 and 1 are used as the x and y positions.
    all_labels : list
        One label per row of ``all_coords``.
    out_path : str
        Destination image path; the figure is saved at 300 dpi.
    """
    label_array = np.array(all_labels)
    n_colours = len(CELL_LINE_PALETTE)

    figure, axis = plt.subplots(figsize=(7, 6))

    # dict.fromkeys de-duplicates while keeping first-appearance order, so
    # colours are assigned in the order the labels were supplied.
    for position, name in enumerate(dict.fromkeys(all_labels)):
        selected = label_array == name
        axis.scatter(
            all_coords[selected, 0],
            all_coords[selected, 1],
            c=CELL_LINE_PALETTE[position % n_colours],  # wrap when labels outnumber colours
            s=3,
            alpha=0.6,
            label=name,
            linewidths=0,
            rasterized=True,
        )

    axis.legend(markerscale=4, title="Cell line", fontsize=9,
                title_fontsize=9, frameon=False)
    axis.set_title("Joint UMAP — chromatin topology embeddings", fontsize=11)
    axis.set_xlabel("UMAP 1", fontsize=9)
    axis.set_ylabel("UMAP 2", fontsize=9)
    # Ticks and tick labels are hidden on both axes.
    axis.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)

    figure.tight_layout()
    figure.savefig(out_path, dpi=300)
    plt.close(figure)
def _silhouette(emb: np.ndarray, compartments: np.ndarray) -> float:
    """Return the cosine silhouette score of ``emb`` grouped by compartment.

    Rows labelled ``"N"`` and rows whose label is missing (NaN/None) are
    dropped before scoring.

    Parameters
    ----------
    emb : np.ndarray
        Embedding matrix, one row per label.
    compartments : np.ndarray
        One compartment label per row of ``emb``.

    Returns
    -------
    float
        The silhouette score, or NaN when fewer than 20 labelled rows remain,
        when fewer than two distinct labels remain, or when scoring fails
        (in which case a ``RuntimeWarning`` describes the failure).
    """
    import warnings

    labels = np.asarray(compartments, dtype=object)

    # Drop missing labels first, then compare only the remaining entries with
    # "N"; a missing label would otherwise pass the != "N" filter and make
    # the scorer reject the whole label vector.
    keep = ~np.asarray(pd.isna(labels), dtype=bool)
    keep[keep] = labels[keep] != "N"

    if keep.sum() < 20 or len(set(labels[keep])) < 2:
        return float("nan")

    try:
        return float(
            silhouette_score(emb[keep], labels[keep].astype(str), metric="cosine")
        )
    except Exception as err:
        # Deliberately broad: the score is a best-effort summary statistic and
        # must not abort figure generation, but the failure is reported
        # instead of being swallowed silently.
        warnings.warn(
            f"silhouette score could not be computed: {err}",
            RuntimeWarning,
            stacklevel=2,
        )
        return float("nan")
def main(argv=None):
    """Command-line entry point: plot UMAP projections and write summary stats.

    For every embedding this writes a position-coloured UMAP and, when a
    compartment CSV is available for it, a compartment-coloured UMAP plus a
    silhouette score. With more than one embedding a joint UMAP of all rows
    is also written. Per-embedding statistics go to ``{prefix}_stats.csv``.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of the process command line. ``None``
        (the default) keeps the original behaviour of reading ``sys.argv``.

    Raises
    ------
    ValueError
        If ``--emb`` and ``--labels`` differ in length, an embedding is not
        2-D, a compartment CSV lacks a ``compartment`` column, or a
        compartment CSV has fewer rows than its embedding.
    """
    import sys

    p = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    p.add_argument("--emb", nargs="+", required=True,
                   help="One or more .npy embedding files")
    p.add_argument("--labels", nargs="+", required=True,
                   help="Label for each embedding (same order)")
    p.add_argument("--compartments", nargs="+",
                   help="Compartment CSV files, one per embedding (optional)")
    p.add_argument("--prefix", default="results/figures/umap",
                   help="Output file prefix")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args(argv)

    if len(args.emb) != len(args.labels):
        raise ValueError("--emb and --labels must have the same length")

    comp_paths = args.compartments or []
    if comp_paths and len(comp_paths) != len(args.emb):
        # Files are matched to embeddings by position; a count mismatch is
        # tolerated as before, but it is no longer silent.
        print(
            f"warning: {len(comp_paths)} compartment file(s) given for "
            f"{len(args.emb)} embedding(s); files are matched by position",
            file=sys.stderr,
        )

    # The "_x" suffix makes dirname() yield the prefix's parent directory
    # even when the prefix itself ends in a path separator.
    os.makedirs(os.path.dirname(os.path.abspath(args.prefix + "_x")), exist_ok=True)

    embs = [np.load(f) for f in args.emb]
    for path, emb in zip(args.emb, embs):
        if emb.ndim != 2:
            raise ValueError(
                f"{path}: expected a 2-D embedding array, got shape {emb.shape}"
            )

    comp_dfs = []
    for path in comp_paths:
        if path and os.path.exists(path):
            comp_dfs.append(pd.read_csv(path))
        else:
            # Keep the slot so later files stay aligned with their embedding.
            print(f"warning: compartment file not found, skipping: {path!r}",
                  file=sys.stderr)
            comp_dfs.append(None)

    stats_rows = []
    for i, (emb, label) in enumerate(zip(embs, args.labels)):
        print(f"\n[{label}] {emb.shape[0]} nodes × {emb.shape[1]} dims")
        coords = _run_umap(emb, seed=args.seed)
        tag = label.replace(" ", "_")

        position_path = f"{args.prefix}_{tag}_position.png"
        _plot_position(coords, label, position_path)
        print(position_path)

        sil = float("nan")
        comp_df = comp_dfs[i] if i < len(comp_dfs) else None
        if comp_df is not None:
            if "compartment" not in comp_df.columns:
                raise ValueError(
                    f"{comp_paths[i]}: no 'compartment' column "
                    f"(found {list(comp_df.columns)})"
                )
            # Extra trailing rows are dropped so labels line up with the
            # embedding rows.
            comp_arr = comp_df["compartment"].to_numpy()[: len(emb)]
            if len(comp_arr) < len(emb):
                # Too few labels would otherwise surface later as an opaque
                # boolean-index error inside the plotting code.
                raise ValueError(
                    f"{comp_paths[i]}: {len(comp_arr)} compartment rows for "
                    f"{len(emb)} embedding rows"
                )
            compartment_path = f"{args.prefix}_{tag}_compartment.png"
            _plot_compartment(coords, comp_arr, label, compartment_path)
            print(compartment_path)
            sil = _silhouette(emb, comp_arr)
            print(f" Silhouette (A/B, cosine): {sil:.4f}")

        stats_rows.append({
            "label": label,
            "n_bins": emb.shape[0],
            "latent_dim": emb.shape[1],
            "mean_embedding_norm": float(np.linalg.norm(emb, axis=1).mean()),
            "std_embedding_values": float(emb.std()),
            "silhouette_AB_cosine": sil,
        })

    # Joint UMAP when multiple embeddings are supplied
    if len(embs) > 1:
        print("\nComputing joint UMAP…")
        all_emb = np.vstack(embs)
        # One label per stacked row, in the same order as the vstack above.
        all_labels = [lab for lab, e in zip(args.labels, embs) for _ in range(len(e))]
        all_coords = _run_umap(all_emb, seed=args.seed)
        joint_path = f"{args.prefix}_joint.png"
        _plot_joint(all_coords, all_labels, joint_path)
        print(joint_path)

    stats_df = pd.DataFrame(stats_rows)
    stats_path = f"{args.prefix}_stats.csv"
    stats_df.to_csv(stats_path, index=False)
    print(f"\nStats → {stats_path}")
    print(stats_df.to_string(index=False))
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()