v1.0.0: VGAE applied to GM12878 vs IMR90 chr21 Hi-C at 25kb
Full reproducible pipeline: .mcool + ChIP-seq bigwigs → latent embeddings → A/B compartment calls → cross-cell comparison. Key results (chr21, 25 kb, latent dim=32): - Test AUC=0.777, AP=0.759 (converged epoch 31/300) - GM12878 A/B silhouette (cosine) = 0.775 - IMR90 zero-shot silhouette = 0.443 - A-compartment bins stable across cell types (mean cosine Δ=0.042) - B-compartment bins shift substantially (mean cosine Δ=0.451) - 101 B→A and 70 A→B compartment switches GM12878→IMR90
This commit is contained in:
189
scripts/visualize_embeddings.py
Normal file
189
scripts/visualize_embeddings.py
Normal file
@@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Visualise VGAE node embeddings using UMAP.
|
||||
|
||||
Produces (under --prefix):
|
||||
{prefix}_{label}_position.png UMAP coloured by genomic position (bin index)
|
||||
{prefix}_{label}_compartment.png UMAP coloured by A/B compartment (needs --compartments)
|
||||
{prefix}_joint.png Joint UMAP of all supplied cell lines
|
||||
{prefix}_stats.csv Per-embedding summary statistics
|
||||
|
||||
Usage
|
||||
-----
|
||||
python scripts/visualize_embeddings.py \\
|
||||
--emb results/GM12878/emb.npy results/IMR90/emb.npy \\
|
||||
--labels GM12878 IMR90 \\
|
||||
--compartments results/GM12878/compartments_chr21.csv \\
|
||||
results/IMR90/compartments_chr21.csv \\
|
||||
--prefix results/figures/umap
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import silhouette_score
|
||||
import umap
|
||||
|
||||
|
||||
COMPARTMENT_COLORS = {"A": "#E41A1C", "B": "#377EB8", "N": "#AAAAAA"}
|
||||
CELL_LINE_PALETTE = ["#E41A1C", "#4DAF4A", "#984EA3", "#FF7F00", "#377EB8"]
|
||||
|
||||
plt.rcParams.update({
|
||||
"font.family": "sans-serif",
|
||||
"axes.spines.top": False,
|
||||
"axes.spines.right": False,
|
||||
})
|
||||
|
||||
|
||||
def _run_umap(emb: np.ndarray, seed: int = 42) -> np.ndarray:
|
||||
reducer = umap.UMAP(n_components=2, random_state=seed,
|
||||
min_dist=0.3, n_neighbors=15)
|
||||
return reducer.fit_transform(emb)
|
||||
|
||||
|
||||
def _plot_position(coords: np.ndarray, label: str, out_path: str):
|
||||
fig, ax = plt.subplots(figsize=(6.5, 5.5))
|
||||
sc = ax.scatter(coords[:, 0], coords[:, 1],
|
||||
c=np.arange(len(coords)), cmap="plasma",
|
||||
s=4, alpha=0.75, linewidths=0, rasterized=True)
|
||||
cbar = plt.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
|
||||
cbar.set_label("Bin index (5′ → 3′)", fontsize=9)
|
||||
ax.set_title(f"{label} — UMAP coloured by genomic position", fontsize=10)
|
||||
ax.set_xlabel("UMAP 1", fontsize=9)
|
||||
ax.set_ylabel("UMAP 2", fontsize=9)
|
||||
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(out_path, dpi=300)
|
||||
plt.close()
|
||||
|
||||
|
||||
def _plot_compartment(coords: np.ndarray, compartments: np.ndarray,
|
||||
label: str, out_path: str):
|
||||
fig, ax = plt.subplots(figsize=(6.5, 5.5))
|
||||
for comp in ("A", "B", "N"):
|
||||
mask = compartments == comp
|
||||
if mask.sum() == 0:
|
||||
continue
|
||||
ax.scatter(coords[mask, 0], coords[mask, 1],
|
||||
c=COMPARTMENT_COLORS[comp], s=4, alpha=0.75,
|
||||
label=f"{comp} ({mask.sum()} bins)", linewidths=0,
|
||||
rasterized=True)
|
||||
ax.legend(markerscale=3, title="Compartment", fontsize=9,
|
||||
title_fontsize=9, frameon=False)
|
||||
ax.set_title(f"{label} — UMAP coloured by A/B compartment", fontsize=10)
|
||||
ax.set_xlabel("UMAP 1", fontsize=9)
|
||||
ax.set_ylabel("UMAP 2", fontsize=9)
|
||||
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(out_path, dpi=300)
|
||||
plt.close()
|
||||
|
||||
|
||||
def _plot_joint(all_coords: np.ndarray, all_labels: list, out_path: str):
|
||||
fig, ax = plt.subplots(figsize=(7, 6))
|
||||
unique = list(dict.fromkeys(all_labels))
|
||||
arr = np.array(all_labels)
|
||||
for i, label in enumerate(unique):
|
||||
mask = arr == label
|
||||
ax.scatter(all_coords[mask, 0], all_coords[mask, 1],
|
||||
c=CELL_LINE_PALETTE[i % len(CELL_LINE_PALETTE)],
|
||||
s=3, alpha=0.6, label=label, linewidths=0, rasterized=True)
|
||||
ax.legend(markerscale=4, title="Cell line", fontsize=9,
|
||||
title_fontsize=9, frameon=False)
|
||||
ax.set_title("Joint UMAP — chromatin topology embeddings", fontsize=11)
|
||||
ax.set_xlabel("UMAP 1", fontsize=9)
|
||||
ax.set_ylabel("UMAP 2", fontsize=9)
|
||||
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
||||
plt.tight_layout()
|
||||
plt.savefig(out_path, dpi=300)
|
||||
plt.close()
|
||||
|
||||
|
||||
def _silhouette(emb: np.ndarray, compartments: np.ndarray) -> float:
|
||||
mask = compartments != "N"
|
||||
if mask.sum() < 20 or len(set(compartments[mask])) < 2:
|
||||
return float("nan")
|
||||
try:
|
||||
return float(silhouette_score(emb[mask], compartments[mask], metric="cosine"))
|
||||
except Exception:
|
||||
return float("nan")
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
p.add_argument("--emb", nargs="+", required=True,
|
||||
help="One or more .npy embedding files")
|
||||
p.add_argument("--labels", nargs="+", required=True,
|
||||
help="Label for each embedding (same order)")
|
||||
p.add_argument("--compartments", nargs="+",
|
||||
help="Compartment CSV files, one per embedding (optional)")
|
||||
p.add_argument("--prefix", default="results/figures/umap",
|
||||
help="Output file prefix")
|
||||
p.add_argument("--seed", type=int, default=42)
|
||||
args = p.parse_args()
|
||||
|
||||
if len(args.emb) != len(args.labels):
|
||||
raise ValueError("--emb and --labels must have the same length")
|
||||
|
||||
os.makedirs(os.path.dirname(os.path.abspath(args.prefix + "_x")), exist_ok=True)
|
||||
|
||||
embs = [np.load(f) for f in args.emb]
|
||||
comp_dfs = []
|
||||
if args.compartments:
|
||||
for f in args.compartments:
|
||||
comp_dfs.append(pd.read_csv(f) if (f and os.path.exists(f)) else None)
|
||||
|
||||
stats_rows = []
|
||||
|
||||
for i, (emb, label) in enumerate(zip(embs, args.labels)):
|
||||
print(f"\n[{label}] {emb.shape[0]} nodes × {emb.shape[1]} dims")
|
||||
coords = _run_umap(emb, seed=args.seed)
|
||||
|
||||
tag = label.replace(" ", "_")
|
||||
_plot_position(coords, label, f"{args.prefix}_{tag}_position.png")
|
||||
print(f" → {args.prefix}_{tag}_position.png")
|
||||
|
||||
comp_arr = None
|
||||
sil = float("nan")
|
||||
if comp_dfs and i < len(comp_dfs) and comp_dfs[i] is not None:
|
||||
comp_arr = comp_dfs[i]["compartment"].values[: len(emb)]
|
||||
_plot_compartment(coords, comp_arr, label,
|
||||
f"{args.prefix}_{tag}_compartment.png")
|
||||
print(f" → {args.prefix}_{tag}_compartment.png")
|
||||
sil = _silhouette(emb, comp_arr)
|
||||
print(f" Silhouette (A/B, cosine): {sil:.4f}")
|
||||
|
||||
stats_rows.append({
|
||||
"label": label,
|
||||
"n_bins": emb.shape[0],
|
||||
"latent_dim": emb.shape[1],
|
||||
"mean_embedding_norm": float(np.linalg.norm(emb, axis=1).mean()),
|
||||
"std_embedding_values": float(emb.std()),
|
||||
"silhouette_AB_cosine": sil,
|
||||
})
|
||||
|
||||
# Joint UMAP when multiple embeddings are supplied
|
||||
if len(embs) > 1:
|
||||
print("\nComputing joint UMAP…")
|
||||
all_emb = np.vstack(embs)
|
||||
all_labels = sum([[lab] * len(e) for lab, e in zip(args.labels, embs)], [])
|
||||
all_coords = _run_umap(all_emb, seed=args.seed)
|
||||
_plot_joint(all_coords, all_labels, f"{args.prefix}_joint.png")
|
||||
print(f" → {args.prefix}_joint.png")
|
||||
|
||||
stats_df = pd.DataFrame(stats_rows)
|
||||
stats_path = f"{args.prefix}_stats.csv"
|
||||
stats_df.to_csv(stats_path, index=False)
|
||||
print(f"\nStats → {stats_path}")
|
||||
print(stats_df.to_string(index=False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user