Files
chromatin-vgae-hic/experiments/h2_rewiring/perturbation_viz.py

239 lines
9.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Step 8: Perturbation experiment visualisations.
Generates four figures:
1. UMAP of control vs treated embeddings (coloured by condition)
2. Genome browser-style short- and long-range drift tracks (plus loop-anchor track) along chr1
3. Scatter: long-range drift (x) vs short-range drift (y), loop anchors highlighted
4. Bar plot: mean short-range drift at loop anchors vs non-anchors
Usage
-----
python scripts/perturbation_viz.py \
--control_emb results/perturbation/control_emb.npy \
--treated_emb results/perturbation/treated_emb.npy \
--drift_full results/perturbation/drift_full.npy \
--drift_short results/perturbation/drift_short.npy \
--drift_long results/perturbation/drift_long.npy \
--anchor_mask results/perturbation/anchor_mask.npy \
--stats results/perturbation/drift_stats.json \
--outdir results/perturbation/figures \
--res 25000
"""
import argparse
import json
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import umap
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _savefig(fig, path, dpi=150):
    """Write *fig* to *path*, close it, and print the saved location.

    Parameters
    ----------
    fig : matplotlib figure
        Figure to save. It is always closed afterwards, even if saving fails.
    path : str
        Destination file. Its parent directory is created when missing.
    dpi : int, optional
        Resolution forwarded to ``fig.savefig``.
    """
    parent = os.path.dirname(path)
    # A bare filename has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    try:
        fig.savefig(path, dpi=dpi, bbox_inches="tight")
    finally:
        # Release the figure even when writing fails so it does not linger
        # in pyplot's figure registry.
        plt.close(fig)
    print(f"Saved → {path}")
# ---------------------------------------------------------------------------
# Figure 1: UMAP control vs treated
# ---------------------------------------------------------------------------
def fig_umap(ctrl_emb, trt_emb, outdir):
    """Embed control and treated embeddings with one UMAP fit and plot them.

    The two matrices are stacked row-wise (control first) and reduced to two
    dimensions in a single fit, so both conditions share one coordinate
    system. The scatter plot is written to ``umap_perturbation.png`` in
    *outdir*; the per-condition coordinates are saved next to it as
    ``umap_ctrl.npy`` and ``umap_trt.npy``.
    """
    print("Fitting UMAP ...")
    stacked = np.vstack([ctrl_emb, trt_emb])
    n_ctrl = ctrl_emb.shape[0]
    model = umap.UMAP(
        n_components=2,
        n_neighbors=30,
        min_dist=0.1,
        metric="cosine",
        random_state=42,
    )
    coords = model.fit_transform(stacked)
    # Rows keep the stacking order: control rows first, treated rows after.
    ctrl_xy, trt_xy = coords[:n_ctrl], coords[n_ctrl:]

    fig, ax = plt.subplots(figsize=(6, 5))
    layers = (
        (ctrl_xy, "#2166ac", "Control (untreated)"),
        (trt_xy, "#d6604d", "Treated (6 h auxin)"),
    )
    for points, colour, label in layers:
        ax.scatter(points[:, 0], points[:, 1],
                   s=4, alpha=0.6, color=colour, label=label)
    ax.set_xlabel("UMAP 1", fontsize=11)
    ax.set_ylabel("UMAP 2", fontsize=11)
    ax.set_title("VGAE embeddings: control vs RAD21-depleted\n"
                 "(HCT-116 G1/S, chr1, 25 kb)", fontsize=10)
    ax.legend(fontsize=9, markerscale=3)
    _savefig(fig, os.path.join(outdir, "umap_perturbation.png"))

    # Keep the 2-D coordinates so downstream steps can reuse them.
    for fname, points in (("umap_ctrl.npy", ctrl_xy), ("umap_trt.npy", trt_xy)):
        np.save(os.path.join(outdir, fname), points)
# ---------------------------------------------------------------------------
# Figure 2: Genome-browser drift track
# ---------------------------------------------------------------------------
def fig_drift_track(drift_long, drift_short, anchor_mask, res_bp, outdir):
    """Draw stacked short-range, long-range and loop-anchor tracks.

    Parameters
    ----------
    drift_long, drift_short : np.ndarray
        Drift values, one per bin; bin ``i`` is placed at ``i * res_bp``.
    anchor_mask : np.ndarray
        Loop-anchor indicator per bin, drawn as a 0/1 track.
    res_bp : int
        Bin size in base pairs, used to convert bin index to Mb.
    outdir : str
        Directory receiving ``drift_track_chr1.png``.
    """
    n_bins = len(drift_long)
    positions = np.arange(n_bins) * res_bp / 1e6  # Mb

    fig, axes = plt.subplots(3, 1, figsize=(14, 7), sharex=True,
                             gridspec_kw={"height_ratios": [3, 3, 1]})

    # Top panel: short-range drift; middle panel: long-range drift.
    drift_panels = (
        (axes[0], drift_short, "#d6604d",
         "Short-range drift\n(< 1 Mb; signal)"),
        (axes[1], drift_long, "#4393c3",
         "Long-range drift\n(25 Mb; background)"),
    )
    for ax, values, colour, ylabel in drift_panels:
        ax.fill_between(positions, values, alpha=0.7, color=colour,
                        linewidth=0)
        ax.set_ylabel(ylabel, fontsize=9)
        # Scale each panel from the series it actually displays so the filled
        # area is never clipped by the other series' range. The 0.01 floor
        # keeps an all-zero track from collapsing the axis.
        ax.set_ylim(0, max(values.max(), 0.01) * 1.1)

    axes[0].set_title(
        "Embedding drift along chr1 — HCT-116 RAD21 depletion (6 h auxin)\n"
        "Short-range = loop scale (expected HIGH at anchors); "
        "Long-range = sub-compartment (expected LOW)",
        fontsize=9)

    # Bottom panel: loop-anchor track.
    anchor_ax = axes[2]
    anchor_ax.fill_between(positions, anchor_mask.astype(float),
                           color="#1b7837", alpha=0.8, linewidth=0)
    anchor_ax.set_ylabel("Loop\nanchor", fontsize=9)
    anchor_ax.set_ylim(0, 1.2)
    anchor_ax.set_yticks([])
    anchor_ax.set_xlabel("chr1 position (Mb)", fontsize=10)

    # Lay out this figure explicitly rather than whatever pyplot considers
    # the current figure.
    fig.tight_layout()
    _savefig(fig, os.path.join(outdir, "drift_track_chr1.png"))
# ---------------------------------------------------------------------------
# Figure 3: Long-range vs short-range scatter
# ---------------------------------------------------------------------------
def fig_scatter(drift_long, drift_short, anchor_mask, outdir):
    """Scatter long-range drift (x) against short-range drift (y).

    Bins selected by *anchor_mask* are drawn as larger highlighted markers;
    the remaining bins form a grey background layer. Both axes share the same
    upper limit and a dashed y = x guide is overlaid. The figure is written
    to ``scatter_lr_vs_sr_drift.png`` in *outdir*.
    """
    background = ~anchor_mask
    fig, ax = plt.subplots(figsize=(5.5, 5))

    # Background first, anchors second, so anchors are drawn on top.
    layers = (
        (background,
         dict(s=3, alpha=0.3, color="#aaaaaa", label="Non-anchor")),
        (anchor_mask,
         dict(s=10, alpha=0.8, color="#d6604d",
              label=f"Loop anchor (n={anchor_mask.sum()})")),
    )
    for selector, style in layers:
        ax.scatter(drift_long[selector], drift_short[selector],
                   rasterized=True, **style)

    ax.set_xlabel("Long-range cosine drift (25 Mb; sub-compartment scale)",
                  fontsize=9)
    ax.set_ylabel("Short-range cosine drift (< 1 Mb; loop scale)", fontsize=9)
    ax.set_title("Loop-scale vs sub-compartment-scale drift\n"
                 "HCT-116 RAD21 depletion — CTCF anchors expected upper-left",
                 fontsize=9)
    ax.legend(fontsize=9, markerscale=2)

    # Common limit on both axes (5 % headroom over the largest value) plus
    # the diagonal guide.
    upper = max(drift_long.max(), drift_short.max()) * 1.05
    ax.set_xlim(0, upper)
    ax.set_ylim(0, upper)
    ax.plot([0, upper], [0, upper], "k--", lw=0.8, alpha=0.3)
    _savefig(fig, os.path.join(outdir, "scatter_lr_vs_sr_drift.png"))
# ---------------------------------------------------------------------------
# Figure 4: Bar plot with significance
# ---------------------------------------------------------------------------
def fig_barplot(drift_short, anchor_mask, stats_path, outdir):
    """Bar plot of mean short-range drift at anchors versus non-anchors.

    Error bars are the standard deviation (NumPy default) divided by the
    square root of the group size. When *stats_path* exists and its JSON
    contains ``perm_p_value``, a bracket with a significance label is drawn
    above the bars. The figure is written to ``barplot_anchor_drift.png`` in
    *outdir*.
    """
    groups = (drift_short[anchor_mask], drift_short[~anchor_mask])
    means = [group.mean() for group in groups]
    sems = [group.std() / np.sqrt(len(group)) for group in groups]

    # The p-value is optional: a missing stats file or key leaves it as None.
    p_val = None
    if os.path.exists(stats_path):
        with open(stats_path) as fh:
            stats = json.load(fh)
        p_val = stats.get("perm_p_value")

    fig, ax = plt.subplots(figsize=(4, 4.5))
    ax.bar(["Loop anchors\n(CTCF bins)", "Non-anchors"], means,
           yerr=sems, capsize=5, color=["#d6604d", "#aaaaaa"],
           edgecolor="black", linewidth=0.8, error_kw={"linewidth": 1.5})
    ax.set_ylabel("Mean short-range cosine drift (< 1 Mb)", fontsize=10)
    ax.set_title("Short-range drift enrichment\nat loop anchors\n"
                 "(loop-scale contacts lost after RAD21 depletion)", fontsize=9)

    if p_val is not None:
        # Bracket sits 15 % above the tallest bar-plus-error; the axis is
        # extended a further 20 % to leave room for the label.
        bracket_y = max(m + s for m, s in zip(means, sems)) * 1.15
        ax.set_ylim(0, bracket_y * 1.2)
        ax.plot([0, 1], [bracket_y, bracket_y], color="black", lw=1.2)
        if p_val < 0.001:
            sig = "***"
        elif p_val < 0.01:
            sig = "**"
        elif p_val < 0.05:
            sig = "*"
        else:
            sig = f"p={p_val:.3f}"
        ax.text(0.5, bracket_y * 1.03, sig, ha="center", va="bottom",
                fontsize=12)

    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    _savefig(fig, os.path.join(outdir, "barplot_anchor_drift.png"))
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Parse CLI arguments, load the input arrays and render all four figures.

    Exits with an argparse usage error when the drift arrays and the anchor
    mask do not have the same length.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--control_emb", required=True)
    ap.add_argument("--treated_emb", required=True)
    # No figure consumes the full-range drift, so the file is neither
    # required nor loaded; the flag stays so existing invocations keep working.
    ap.add_argument("--drift_full", default=None,
                    help="Accepted for compatibility; not used by any figure")
    ap.add_argument("--drift_short", required=True)
    ap.add_argument("--drift_long", required=True)
    ap.add_argument("--anchor_mask", required=True)
    ap.add_argument("--stats", required=True, help="drift_stats.json")
    ap.add_argument("--outdir", default="results/perturbation/figures")
    ap.add_argument("--res", type=int, default=25_000,
                    help="Bin size in bp")
    args = ap.parse_args()
    os.makedirs(args.outdir, exist_ok=True)

    ctrl_emb = np.load(args.control_emb)
    trt_emb = np.load(args.treated_emb)
    drift_short = np.load(args.drift_short)
    drift_long = np.load(args.drift_long)
    # The figure functions apply ``~mask`` and use the mask as an index.
    # Both only select bins correctly for a boolean array; a 0/1 integer mask
    # would silently become bitwise-not plus integer fancy indexing.
    anchor_mask = np.load(args.anchor_mask).astype(bool)

    # Fail early with a readable message instead of an indexing error deep
    # inside the plotting code.
    n_bins = len(drift_long)
    if len(drift_short) != n_bins or len(anchor_mask) != n_bins:
        ap.error(
            f"drift_long ({n_bins}), drift_short ({len(drift_short)}) and "
            f"anchor_mask ({len(anchor_mask)}) must have the same length")

    print(f"Embeddings: {ctrl_emb.shape} Bins: {len(drift_long)}")
    print(f"Anchor bins: {anchor_mask.sum()} / {len(anchor_mask)}")

    fig_umap(ctrl_emb, trt_emb, args.outdir)
    fig_drift_track(drift_long, drift_short, anchor_mask, args.res, args.outdir)
    fig_scatter(drift_long, drift_short, anchor_mask, args.outdir)
    fig_barplot(drift_short, anchor_mask, args.stats, args.outdir)
    print("\nAll figures saved.")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()