Full reproducible pipeline: .mcool + ChIP-seq bigWig files → latent embeddings → A/B compartment calls → cross-cell-type comparison.

Key results (chr21, 25 kb resolution, latent dim = 32):
- Test AUC = 0.777, AP = 0.759 (converged at epoch 31/300)
- GM12878 A/B silhouette (cosine) = 0.775
- IMR90 zero-shot silhouette = 0.443
- A-compartment bins stable across cell types (mean cosine Δ = 0.042)
- B-compartment bins shift substantially (mean cosine Δ = 0.451)
- 101 B→A and 70 A→B compartment switches (GM12878 → IMR90)
190 lines
7.2 KiB
Python
190 lines
7.2 KiB
Python
#!/usr/bin/env python3
"""
Visualise VGAE node embeddings using UMAP.

Each --emb file is loaded with numpy.load and used as a 2-D array with one row
per node (genomic bin) and one column per latent dimension.

Produces (under --prefix):
    {prefix}_{label}_position.png     UMAP coloured by genomic position (bin index)
    {prefix}_{label}_compartment.png  UMAP coloured by A/B compartment (needs --compartments)
    {prefix}_joint.png                Joint UMAP of all supplied cell lines
                                      (only when more than one --emb is given)
    {prefix}_stats.csv                Per-embedding summary statistics

In output file names, spaces in {label} are replaced by underscores.
Compartment CSVs need a "compartment" column with A/B/N labels; rows are
matched to embedding rows by position.

Usage
-----
    python scripts/visualize_embeddings.py \\
        --emb results/GM12878/emb.npy results/IMR90/emb.npy \\
        --labels GM12878 IMR90 \\
        --compartments results/GM12878/compartments_chr21.csv \\
                       results/IMR90/compartments_chr21.csv \\
        --prefix results/figures/umap
"""
|
||
|
||
import argparse
|
||
import os
|
||
|
||
import matplotlib
|
||
matplotlib.use("Agg")
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import pandas as pd
|
||
from sklearn.metrics import silhouette_score
|
||
import umap
|
||
|
||
|
||
# Scatter colour for each compartment label ("N" bins are also excluded from
# the A/B silhouette score).
COMPARTMENT_COLORS = dict(
    A="#E41A1C",
    B="#377EB8",
    N="#AAAAAA",
)

# Colours assigned, in order, to cell lines in the joint UMAP plot.
CELL_LINE_PALETTE = [
    "#E41A1C",
    "#4DAF4A",
    "#984EA3",
    "#FF7F00",
    "#377EB8",
]

# Global matplotlib style for every figure this script writes:
# sans-serif text and no top/right axis spines.
plt.rcParams.update({
    "font.family": "sans-serif",
    **{f"axes.spines.{side}": False for side in ("top", "right")},
})


def _run_umap(emb: np.ndarray, seed: int = 42, *, n_neighbors: int = 15,
              min_dist: float = 0.3) -> np.ndarray:
    """Embed *emb* in two dimensions with UMAP and return the coordinates.

    Parameters
    ----------
    emb : np.ndarray
        Feature matrix passed to ``UMAP.fit_transform`` (one row per point).
    seed : int, default 42
        ``random_state`` for UMAP, so repeated runs reproduce the same layout.
    n_neighbors : int, default 15
        UMAP neighbourhood size (keyword-only).
    min_dist : float, default 0.3
        UMAP minimum distance between embedded points (keyword-only).

    Returns
    -------
    np.ndarray
        Result of ``fit_transform``: one 2-D coordinate per row of *emb*.
    """
    # Per the umap-learn docs, fixing random_state buys reproducibility at the
    # cost of some parallelism, so large inputs may run slower.
    reducer = umap.UMAP(n_components=2, random_state=seed,
                        min_dist=min_dist, n_neighbors=n_neighbors)
    return reducer.fit_transform(emb)


def _plot_position(coords: np.ndarray, label: str, out_path: str):
    """Save a UMAP scatter of *coords* coloured by row order.

    Row ``i`` is coloured by the value ``i`` on the plasma colormap, so the
    colour gradient follows the order of bins in the embedding. The figure is
    written to *out_path* at 300 dpi and then closed.
    """
    xs = coords[:, 0]
    ys = coords[:, 1]
    order = np.arange(len(coords))

    figure, axes = plt.subplots(figsize=(6.5, 5.5))
    points = axes.scatter(
        xs,
        ys,
        c=order,
        cmap="plasma",
        s=4,
        alpha=0.75,
        linewidths=0,
        rasterized=True,
    )
    colorbar = figure.colorbar(points, ax=axes, shrink=0.8, pad=0.02)
    colorbar.set_label("Bin index (5′ → 3′)", fontsize=9)

    axes.set_title(f"{label} — UMAP coloured by genomic position", fontsize=10)
    axes.set_xlabel("UMAP 1", fontsize=9)
    axes.set_ylabel("UMAP 2", fontsize=9)
    # UMAP axes carry no absolute meaning, so ticks and tick labels are hidden.
    axes.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)

    figure.tight_layout()
    figure.savefig(out_path, dpi=300)
    plt.close(figure)


def _plot_compartment(coords: np.ndarray, compartments: np.ndarray,
|
||
label: str, out_path: str):
|
||
fig, ax = plt.subplots(figsize=(6.5, 5.5))
|
||
for comp in ("A", "B", "N"):
|
||
mask = compartments == comp
|
||
if mask.sum() == 0:
|
||
continue
|
||
ax.scatter(coords[mask, 0], coords[mask, 1],
|
||
c=COMPARTMENT_COLORS[comp], s=4, alpha=0.75,
|
||
label=f"{comp} ({mask.sum()} bins)", linewidths=0,
|
||
rasterized=True)
|
||
ax.legend(markerscale=3, title="Compartment", fontsize=9,
|
||
title_fontsize=9, frameon=False)
|
||
ax.set_title(f"{label} — UMAP coloured by A/B compartment", fontsize=10)
|
||
ax.set_xlabel("UMAP 1", fontsize=9)
|
||
ax.set_ylabel("UMAP 2", fontsize=9)
|
||
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
||
plt.tight_layout()
|
||
plt.savefig(out_path, dpi=300)
|
||
plt.close()
|
||
|
||
|
||
def _plot_joint(all_coords: np.ndarray, all_labels: list, out_path: str):
|
||
fig, ax = plt.subplots(figsize=(7, 6))
|
||
unique = list(dict.fromkeys(all_labels))
|
||
arr = np.array(all_labels)
|
||
for i, label in enumerate(unique):
|
||
mask = arr == label
|
||
ax.scatter(all_coords[mask, 0], all_coords[mask, 1],
|
||
c=CELL_LINE_PALETTE[i % len(CELL_LINE_PALETTE)],
|
||
s=3, alpha=0.6, label=label, linewidths=0, rasterized=True)
|
||
ax.legend(markerscale=4, title="Cell line", fontsize=9,
|
||
title_fontsize=9, frameon=False)
|
||
ax.set_title("Joint UMAP — chromatin topology embeddings", fontsize=11)
|
||
ax.set_xlabel("UMAP 1", fontsize=9)
|
||
ax.set_ylabel("UMAP 2", fontsize=9)
|
||
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
||
plt.tight_layout()
|
||
plt.savefig(out_path, dpi=300)
|
||
plt.close()
|
||
|
||
|
||
def _silhouette(emb: np.ndarray, compartments: np.ndarray) -> float:
|
||
mask = compartments != "N"
|
||
if mask.sum() < 20 or len(set(compartments[mask])) < 2:
|
||
return float("nan")
|
||
try:
|
||
return float(silhouette_score(emb[mask], compartments[mask], metric="cosine"))
|
||
except Exception:
|
||
return float("nan")
|
||
|
||
|
||
def main():
|
||
p = argparse.ArgumentParser(
|
||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||
)
|
||
p.add_argument("--emb", nargs="+", required=True,
|
||
help="One or more .npy embedding files")
|
||
p.add_argument("--labels", nargs="+", required=True,
|
||
help="Label for each embedding (same order)")
|
||
p.add_argument("--compartments", nargs="+",
|
||
help="Compartment CSV files, one per embedding (optional)")
|
||
p.add_argument("--prefix", default="results/figures/umap",
|
||
help="Output file prefix")
|
||
p.add_argument("--seed", type=int, default=42)
|
||
args = p.parse_args()
|
||
|
||
if len(args.emb) != len(args.labels):
|
||
raise ValueError("--emb and --labels must have the same length")
|
||
|
||
os.makedirs(os.path.dirname(os.path.abspath(args.prefix + "_x")), exist_ok=True)
|
||
|
||
embs = [np.load(f) for f in args.emb]
|
||
comp_dfs = []
|
||
if args.compartments:
|
||
for f in args.compartments:
|
||
comp_dfs.append(pd.read_csv(f) if (f and os.path.exists(f)) else None)
|
||
|
||
stats_rows = []
|
||
|
||
for i, (emb, label) in enumerate(zip(embs, args.labels)):
|
||
print(f"\n[{label}] {emb.shape[0]} nodes × {emb.shape[1]} dims")
|
||
coords = _run_umap(emb, seed=args.seed)
|
||
|
||
tag = label.replace(" ", "_")
|
||
_plot_position(coords, label, f"{args.prefix}_{tag}_position.png")
|
||
print(f" → {args.prefix}_{tag}_position.png")
|
||
|
||
comp_arr = None
|
||
sil = float("nan")
|
||
if comp_dfs and i < len(comp_dfs) and comp_dfs[i] is not None:
|
||
comp_arr = comp_dfs[i]["compartment"].values[: len(emb)]
|
||
_plot_compartment(coords, comp_arr, label,
|
||
f"{args.prefix}_{tag}_compartment.png")
|
||
print(f" → {args.prefix}_{tag}_compartment.png")
|
||
sil = _silhouette(emb, comp_arr)
|
||
print(f" Silhouette (A/B, cosine): {sil:.4f}")
|
||
|
||
stats_rows.append({
|
||
"label": label,
|
||
"n_bins": emb.shape[0],
|
||
"latent_dim": emb.shape[1],
|
||
"mean_embedding_norm": float(np.linalg.norm(emb, axis=1).mean()),
|
||
"std_embedding_values": float(emb.std()),
|
||
"silhouette_AB_cosine": sil,
|
||
})
|
||
|
||
# Joint UMAP when multiple embeddings are supplied
|
||
if len(embs) > 1:
|
||
print("\nComputing joint UMAP…")
|
||
all_emb = np.vstack(embs)
|
||
all_labels = sum([[lab] * len(e) for lab, e in zip(args.labels, embs)], [])
|
||
all_coords = _run_umap(all_emb, seed=args.seed)
|
||
_plot_joint(all_coords, all_labels, f"{args.prefix}_joint.png")
|
||
print(f" → {args.prefix}_joint.png")
|
||
|
||
stats_df = pd.DataFrame(stats_rows)
|
||
stats_path = f"{args.prefix}_stats.csv"
|
||
stats_df.to_csv(stats_path, index=False)
|
||
print(f"\nStats → {stats_path}")
|
||
print(stats_df.to_string(index=False))
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|