Files
chromatin-vgae-hic/scripts/visualize_embeddings.py
aman acadbd780c v1.0.0: VGAE applied to GM12878 vs IMR90 chr21 Hi-C at 25kb
Full reproducible pipeline: .mcool + ChIP-seq bigwigs → latent
  embeddings → A/B compartment calls → cross-cell comparison.

  Key results (chr21, 25 kb, latent dim=32):
  - Test AUC=0.777, AP=0.759 (converged epoch 31/300)
  - GM12878 A/B silhouette (cosine) = 0.775
  - IMR90 zero-shot silhouette = 0.443
  - A-compartment bins stable across cell types (mean cosine Δ=0.042)
  - B-compartment bins shift substantially (mean cosine Δ=0.451)
  - 101 B→A and 70 A→B compartment switches GM12878→IMR90
2026-05-15 01:53:04 +02:00

190 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Visualise VGAE node embeddings using UMAP.
Produces (under --prefix; spaces in a label become underscores in file names):
  {prefix}_{label}_position.png     UMAP coloured by genomic position (bin index)
  {prefix}_{label}_compartment.png  UMAP coloured by A/B compartment (needs --compartments)
  {prefix}_joint.png                Joint UMAP of all supplied cell lines (only when more than one --emb is given)
  {prefix}_stats.csv                Per-embedding summary statistics

Usage
-----
python scripts/visualize_embeddings.py \\
--emb results/GM12878/emb.npy results/IMR90/emb.npy \\
--labels GM12878 IMR90 \\
--compartments results/GM12878/compartments_chr21.csv \\
results/IMR90/compartments_chr21.csv \\
--prefix results/figures/umap
"""
import argparse
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import silhouette_score
import umap
# Scatter colour for each compartment label drawn by the compartment plot.
COMPARTMENT_COLORS = dict(
    A="#E41A1C",
    B="#377EB8",
    N="#AAAAAA",
)

# Colours handed out to cell lines in order; the joint plot wraps around
# to the start when there are more cell lines than entries.
CELL_LINE_PALETTE = [
    "#E41A1C",
    "#4DAF4A",
    "#984EA3",
    "#FF7F00",
    "#377EB8",
]
# Figure-wide matplotlib defaults: sans-serif text and no top/right spines.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["axes.spines.top"] = False
plt.rcParams["axes.spines.right"] = False
def _run_umap(emb: np.ndarray, seed: int = 42) -> np.ndarray:
    """Project *emb* to two dimensions with UMAP.

    A new ``umap.UMAP`` model (2 components, ``n_neighbors=15``,
    ``min_dist=0.3``) is created on every call with ``random_state=seed``
    and fitted on *emb*; the result of ``fit_transform`` is returned.
    """
    umap_kwargs = dict(
        n_components=2,
        n_neighbors=15,
        min_dist=0.3,
        random_state=seed,
    )
    model = umap.UMAP(**umap_kwargs)
    return model.fit_transform(emb)
def _plot_position(coords: np.ndarray, label: str, out_path: str):
    """Save a 2-D scatter of *coords* coloured by row index.

    Args:
        coords: Array whose columns 0 and 1 are plotted as x and y. The row
            index (``0 .. len(coords) - 1``) is used as the colour value.
        label: Name shown in the plot title.
        out_path: Destination file, written at 300 dpi.
    """
    fig, ax = plt.subplots(figsize=(6.5, 5.5))
    try:
        points = ax.scatter(coords[:, 0], coords[:, 1],
                            c=np.arange(len(coords)), cmap="plasma",
                            s=4, alpha=0.75, linewidths=0, rasterized=True)
        cbar = fig.colorbar(points, ax=ax, shrink=0.8, pad=0.02)
        # The primes mark strand orientation (5′ to 3′); without them the
        # label reads as a meaningless numeric range.
        cbar.set_label("Bin index (5′ → 3′)", fontsize=9)
        ax.set_title(f"{label} — UMAP coloured by genomic position", fontsize=10)
        ax.set_xlabel("UMAP 1", fontsize=9)
        ax.set_ylabel("UMAP 2", fontsize=9)
        # UMAP axes carry no units, so ticks and tick labels are hidden.
        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        # Release the figure even when drawing or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
def _plot_compartment(coords: np.ndarray, compartments: np.ndarray,
                      label: str, out_path: str):
    """Save a 2-D scatter of *coords* coloured by compartment label.

    One scatter layer is drawn per label in the fixed order "A", "B", "N",
    using ``COMPARTMENT_COLORS``; a label with no matching rows is skipped.
    Rows whose label is none of these three are not drawn. The legend
    shows the number of bins per label. The figure is written to
    *out_path* at 300 dpi and then closed.
    """
    fig, ax = plt.subplots(figsize=(6.5, 5.5))
    marker_style = dict(s=4, alpha=0.75, linewidths=0, rasterized=True)

    for comp in ("A", "B", "N"):
        selected = compartments == comp
        n_bins = selected.sum()
        if n_bins == 0:
            continue
        ax.scatter(
            coords[selected, 0],
            coords[selected, 1],
            c=COMPARTMENT_COLORS[comp],
            label=f"{comp} ({n_bins} bins)",
            **marker_style,
        )

    ax.legend(title="Compartment", markerscale=3, fontsize=9,
              title_fontsize=9, frameon=False)
    ax.set_title(f"{label} — UMAP coloured by A/B compartment", fontsize=10)
    ax.set_xlabel("UMAP 1", fontsize=9)
    ax.set_ylabel("UMAP 2", fontsize=9)
    ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)

    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.close(fig)
def _plot_joint(all_coords: np.ndarray, all_labels: list, out_path: str):
    """Save one scatter of *all_coords* with a colour per distinct label.

    *all_labels* holds one label per row of *all_coords*. Distinct labels
    are taken in order of first appearance and coloured from
    ``CELL_LINE_PALETTE``, wrapping around when the palette runs out.
    The figure is written to *out_path* at 300 dpi and then closed.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    label_arr = np.array(all_labels)
    n_colours = len(CELL_LINE_PALETTE)

    # dict.fromkeys de-duplicates while keeping first-seen order.
    for idx, name in enumerate(dict.fromkeys(all_labels)):
        rows = label_arr == name
        colour = CELL_LINE_PALETTE[idx % n_colours]
        ax.scatter(all_coords[rows, 0], all_coords[rows, 1],
                   c=colour, label=name,
                   s=3, alpha=0.6, linewidths=0, rasterized=True)

    ax.legend(title="Cell line", markerscale=4, fontsize=9,
              title_fontsize=9, frameon=False)
    ax.set_title("Joint UMAP — chromatin topology embeddings", fontsize=11)
    ax.set_xlabel("UMAP 1", fontsize=9)
    ax.set_ylabel("UMAP 2", fontsize=9)
    ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)

    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.close(fig)
def _silhouette(emb: np.ndarray, compartments: np.ndarray) -> float:
    """Return the cosine silhouette score of the A-versus-B split.

    Only rows labelled exactly "A" or "B" are scored. "N", missing values
    and any other label are excluded, so they can neither form an extra
    cluster nor make scikit-learn reject a mixed-type label array.

    Args:
        emb: Embedding matrix, one row per bin.
        compartments: One label per row of *emb*.

    Returns:
        The silhouette score as a float, or NaN when the labels do not
        line up with *emb*, when fewer than 20 rows are labelled A/B,
        when only one of the two classes is present, or when
        ``silhouette_score`` rejects the input (a warning is emitted in
        the first and last case).
    """
    import warnings  # local import: nothing else in the file needs it

    nan = float("nan")
    labels = np.asarray(compartments, dtype=object)

    # Fail soft but visibly: a mismatch used to surface as a swallowed
    # IndexError and an unexplained NaN.
    if labels.shape[0] != len(emb):
        warnings.warn(
            f"silhouette skipped: {labels.shape[0]} compartment labels "
            f"for {len(emb)} embedding rows"
        )
        return nan

    mask = (labels == "A") | (labels == "B")
    if mask.sum() < 20 or len(set(labels[mask])) < 2:
        return nan

    try:
        score = silhouette_score(emb[mask], labels[mask].astype(str),
                                 metric="cosine")
    except ValueError as err:
        # Raised by scikit-learn for unusable input (e.g. NaN in emb).
        warnings.warn(f"silhouette skipped: {err}")
        return nan
    return float(score)
def main():
    """Command-line entry point.

    Loads every ``--emb`` array, writes a position-coloured UMAP per
    embedding, a compartment-coloured UMAP plus an A/B silhouette score
    where a compartment CSV is available, a joint UMAP when more than one
    embedding is given, and a ``{prefix}_stats.csv`` summary.

    Raises:
        ValueError: If ``--emb`` and ``--labels`` differ in length, an
            embedding is not 2-D, or a compartment CSV has fewer rows
            than its embedding. All are checked before any UMAP is run.
    """
    import warnings  # local import: only needed for the notices below

    p = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    p.add_argument("--emb", nargs="+", required=True,
                   help="One or more .npy embedding files")
    p.add_argument("--labels", nargs="+", required=True,
                   help="Label for each embedding (same order)")
    p.add_argument("--compartments", nargs="+",
                   help="Compartment CSV files, one per embedding (optional)")
    p.add_argument("--prefix", default="results/figures/umap",
                   help="Output file prefix")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()

    if len(args.emb) != len(args.labels):
        raise ValueError("--emb and --labels must have the same length")

    # The "_x" suffix makes dirname() return the prefix's own directory
    # even when the prefix ends in a path separator.
    os.makedirs(os.path.dirname(os.path.abspath(args.prefix + "_x")), exist_ok=True)

    embs = [np.load(f) for f in args.emb]
    for path, emb in zip(args.emb, embs):
        if emb.ndim != 2:
            raise ValueError(
                f"{path}: expected a 2-D embedding array, got shape {emb.shape}"
            )

    # Resolve compartment labels up front so bad input fails before the
    # expensive UMAP runs rather than midway through plotting.
    comp_paths = args.compartments or []
    if len(comp_paths) > len(embs):
        warnings.warn("more --compartments files than --emb files; extras are ignored")
    comp_labels = [None] * len(embs)
    for i, path in enumerate(comp_paths[: len(embs)]):
        if not path or not os.path.exists(path):
            warnings.warn(
                f"[{args.labels[i]}] compartment file not found, "
                f"skipping compartment outputs: {path!r}"
            )
            continue
        values = pd.read_csv(path)["compartment"].values
        n_bins = len(embs[i])
        if len(values) < n_bins:
            raise ValueError(
                f"{path}: {len(values)} compartment rows for {n_bins} embedding bins"
            )
        # Extra trailing rows are dropped so labels align with embedding rows.
        comp_labels[i] = values[:n_bins]

    stats_rows = []
    for emb, label, comp_arr in zip(embs, args.labels, comp_labels):
        print(f"\n[{label}] {emb.shape[0]} nodes × {emb.shape[1]} dims")
        coords = _run_umap(emb, seed=args.seed)
        tag = label.replace(" ", "_")

        position_png = f"{args.prefix}_{tag}_position.png"
        _plot_position(coords, label, position_png)
        print(position_png)

        sil = float("nan")
        if comp_arr is not None:
            compartment_png = f"{args.prefix}_{tag}_compartment.png"
            _plot_compartment(coords, comp_arr, label, compartment_png)
            print(compartment_png)
            sil = _silhouette(emb, comp_arr)
            print(f" Silhouette (A/B, cosine): {sil:.4f}")

        stats_rows.append({
            "label": label,
            "n_bins": emb.shape[0],
            "latent_dim": emb.shape[1],
            "mean_embedding_norm": float(np.linalg.norm(emb, axis=1).mean()),
            "std_embedding_values": float(emb.std()),
            "silhouette_AB_cosine": sil,
        })

    # Joint UMAP when multiple embeddings are supplied
    if len(embs) > 1:
        print("\nComputing joint UMAP…")
        all_emb = np.vstack(embs)
        # One label per stacked row; extend() keeps this linear in row count.
        all_labels = []
        for lab, e in zip(args.labels, embs):
            all_labels.extend([lab] * len(e))
        all_coords = _run_umap(all_emb, seed=args.seed)
        joint_png = f"{args.prefix}_joint.png"
        _plot_joint(all_coords, all_labels, joint_png)
        print(joint_png)

    stats_df = pd.DataFrame(stats_rows)
    stats_path = f"{args.prefix}_stats.csv"
    stats_df.to_csv(stats_path, index=False)
    print(f"\nStats → {stats_path}")
    print(stats_df.to_string(index=False))
# Run the command-line interface only when executed as a script, so that
# importing this module has no side effects beyond its definitions.
if __name__ == "__main__":
    main()