239 lines
9.1 KiB
Python
239 lines
9.1 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Step 8: Perturbation experiment visualisations.
|
||
|
||
Generates four figures:
|
||
1. UMAP of control vs treated embeddings (coloured by condition)
|
||
2. Genome browser-style tracks along chr1: short-range drift, long-range drift, and loop anchors
|
||
3. Scatter: long-range drift (x) vs short-range drift (y), loop anchors highlighted
|
||
4. Bar plot: mean short-range drift at loop anchors vs non-anchors
|
||
|
||
Usage
|
||
-----
|
||
python scripts/perturbation_viz.py \
|
||
--control_emb results/perturbation/control_emb.npy \
|
||
--treated_emb results/perturbation/treated_emb.npy \
|
||
--drift_full results/perturbation/drift_full.npy \
|
||
--drift_short results/perturbation/drift_short.npy \
|
||
--drift_long results/perturbation/drift_long.npy \
|
||
--anchor_mask results/perturbation/anchor_mask.npy \
|
||
--stats results/perturbation/drift_stats.json \
|
||
--outdir results/perturbation/figures \
|
||
--res 25000
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
|
||
import matplotlib
|
||
matplotlib.use("Agg")
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib.patches as mpatches
|
||
import numpy as np
|
||
import umap
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _savefig(fig, path, dpi=150):
    """Save *fig* to *path*, close it, and print where it was written.

    The parent directory of *path* is created when *path* has one; a bare
    filename is written to the current working directory (``os.makedirs``
    cannot be called with the empty string ``os.path.dirname`` returns for
    such a path).

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to save. It is always closed afterwards, even if saving fails,
        so figures do not accumulate in pyplot's global registry.
    path : str
        Destination file path.
    dpi : int, optional
        Output resolution forwarded to ``Figure.savefig`` (default 150).
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    try:
        fig.savefig(path, dpi=dpi, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"Saved → {path}")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Figure 1: UMAP control vs treated
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def fig_umap(ctrl_emb, trt_emb, outdir):
    """Embed control and treated vectors in one shared 2-D UMAP and plot them.

    The two matrices are stacked row-wise (control first) and fitted with a
    single UMAP (2 components, 30 neighbours, ``min_dist=0.1``, cosine metric,
    ``random_state=42``), so both conditions live in the same coordinate
    system. Writes ``umap_perturbation.png`` and the per-condition
    coordinates ``umap_ctrl.npy`` / ``umap_trt.npy`` into *outdir*.
    """
    print("Fitting UMAP ...")
    stacked = np.vstack([ctrl_emb, trt_emb])
    split_at = ctrl_emb.shape[0]

    projector = umap.UMAP(
        n_components=2,
        n_neighbors=30,
        min_dist=0.1,
        metric="cosine",
        random_state=42,
    )
    coords = projector.fit_transform(stacked)
    ctrl_xy, trt_xy = coords[:split_at], coords[split_at:]

    fig, ax = plt.subplots(figsize=(6, 5))
    groups = (
        (ctrl_xy, "#2166ac", "Control (untreated)"),
        (trt_xy, "#d6604d", "Treated (6 h auxin)"),
    )
    for pts, colour, label in groups:
        ax.scatter(pts[:, 0], pts[:, 1], s=4, alpha=0.6, color=colour, label=label)

    ax.set_xlabel("UMAP 1", fontsize=11)
    ax.set_ylabel("UMAP 2", fontsize=11)
    ax.set_title("VGAE embeddings: control vs RAD21-depleted\n"
                 "(HCT-116 G1/S, chr1, 25 kb)", fontsize=10)
    ax.legend(fontsize=9, markerscale=3)
    _savefig(fig, os.path.join(outdir, "umap_perturbation.png"))

    # Persist the 2-D coordinates so later steps can reuse them without refitting.
    np.save(os.path.join(outdir, "umap_ctrl.npy"), ctrl_xy)
    np.save(os.path.join(outdir, "umap_trt.npy"), trt_xy)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Figure 2: Genome-browser drift track
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _track_ylim(values, floor=0.01, pad=1.1):
    """Return an upper y-limit for a drift track whose axis starts at 0.

    Uses the largest finite entry of *values*, floored at *floor* so an empty
    or all-zero track still gets a visible axis, multiplied by *pad* for
    headroom. NaN/inf entries are ignored.
    """
    values = np.asarray(values, dtype=float)
    finite = values[np.isfinite(values)]
    peak = finite.max() if finite.size else 0.0
    return max(peak, floor) * pad


def fig_drift_track(drift_long, drift_short, anchor_mask, res_bp, outdir):
    """Plot genome-browser style drift tracks along chr1.

    Three vertically stacked panels share the x axis, which is the bin start
    position in Mb (``bin_index * res_bp / 1e6``):

    1. short-range drift (red),
    2. long-range drift (blue),
    3. loop-anchor indicator (green).

    Each drift panel's y-limit is computed from the series drawn in that
    same panel, so neither track is clipped or squashed by the other's range.

    Parameters
    ----------
    drift_long, drift_short : array-like
        Per-bin drift values, one entry per bin. Assumed non-negative (the
        y axis starts at 0) and of equal length — TODO confirm upstream.
    anchor_mask : array-like
        Per-bin anchor indicator; truthy entries are loop anchors.
    res_bp : int
        Bin size in base pairs.
    outdir : str
        Output directory; ``drift_track_chr1.png`` is written there.
    """
    drift_long = np.asarray(drift_long, dtype=float)
    drift_short = np.asarray(drift_short, dtype=float)
    anchor = np.asarray(anchor_mask, dtype=bool)
    positions = np.arange(len(drift_long)) * res_bp / 1e6  # Mb

    fig, axes = plt.subplots(3, 1, figsize=(14, 7), sharex=True,
                             gridspec_kw={"height_ratios": [3, 3, 1]})

    # Panel 1: short-range (loop-scale) drift
    ax = axes[0]
    ax.fill_between(positions, drift_short, alpha=0.7, color="#d6604d",
                    linewidth=0)
    ax.set_ylabel("Short-range drift\n(< 1 Mb; signal)", fontsize=9)
    ax.set_title("Embedding drift along chr1 — HCT-116 RAD21 depletion (6 h auxin)\n"
                 "Short-range = loop scale (expected HIGH at anchors); "
                 "Long-range = sub-compartment (expected LOW)",
                 fontsize=9)
    # y-limit from the series drawn in this panel
    ax.set_ylim(0, _track_ylim(drift_short))

    # Panel 2: long-range (sub-compartment-scale) drift
    ax = axes[1]
    ax.fill_between(positions, drift_long, alpha=0.7, color="#4393c3",
                    linewidth=0)
    ax.set_ylabel("Long-range drift\n(2–5 Mb; background)", fontsize=9)
    ax.set_ylim(0, _track_ylim(drift_long))

    # Panel 3: loop-anchor indicator (1 at anchor bins, 0 elsewhere)
    ax = axes[2]
    ax.fill_between(positions, anchor.astype(float),
                    color="#1b7837", alpha=0.8, linewidth=0)
    ax.set_ylabel("Loop\nanchor", fontsize=9)
    ax.set_ylim(0, 1.2)
    ax.set_yticks([])
    ax.set_xlabel("chr1 position (Mb)", fontsize=10)

    fig.tight_layout()
    _savefig(fig, os.path.join(outdir, "drift_track_chr1.png"))
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Figure 3: Long-range vs short-range scatter
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def fig_scatter(drift_long, drift_short, anchor_mask, outdir):
    """Scatter long-range drift (x) against short-range drift (y) per bin.

    Non-anchor bins are drawn small and grey, loop-anchor bins larger and
    red. Both axes share the same limit and a dashed ``y = x`` guide is
    drawn; bins above the diagonal drift more at loop scale than at
    sub-compartment scale. Writes ``scatter_lr_vs_sr_drift.png`` to *outdir*.

    Parameters
    ----------
    drift_long, drift_short : array-like
        Per-bin drift values of equal length.
    anchor_mask : array-like
        Per-bin anchor indicator. It is coerced to ``bool`` because an
        integer 0/1 array would otherwise be used as fancy indices by
        ``arr[mask]``, and ``~mask`` would not be its logical negation.
    outdir : str
        Output directory.
    """
    drift_long = np.asarray(drift_long, dtype=float)
    drift_short = np.asarray(drift_short, dtype=float)
    anchor = np.asarray(anchor_mask, dtype=bool)
    non_anchor = ~anchor

    fig, ax = plt.subplots(figsize=(5.5, 5))
    ax.scatter(drift_long[non_anchor], drift_short[non_anchor],
               s=3, alpha=0.3, color="#aaaaaa", label="Non-anchor", rasterized=True)
    ax.scatter(drift_long[anchor], drift_short[anchor],
               s=10, alpha=0.8, color="#d6604d",
               label=f"Loop anchor (n={anchor.sum()})", rasterized=True)

    ax.set_xlabel("Long-range cosine drift (2–5 Mb; sub-compartment scale)", fontsize=9)
    ax.set_ylabel("Short-range cosine drift (< 1 Mb; loop scale)", fontsize=9)
    ax.set_title("Loop-scale vs sub-compartment-scale drift\n"
                 "HCT-116 RAD21 depletion — CTCF anchors expected upper-left",
                 fontsize=9)
    ax.legend(fontsize=9, markerscale=2)

    # Square axes with a diagonal guide. Ignore NaN/inf bins and fall back to
    # a unit range when no positive value exists, so the limit is never NaN
    # or a degenerate (0, 0) interval.
    both = np.concatenate([drift_long, drift_short])
    both = both[np.isfinite(both)]
    peak = both.max() if both.size else 0.0
    if peak <= 0:
        peak = 1.0
    lim = peak * 1.05
    ax.set_xlim(0, lim)
    ax.set_ylim(0, lim)
    ax.plot([0, lim], [0, lim], "k--", lw=0.8, alpha=0.3)

    _savefig(fig, os.path.join(outdir, "scatter_lr_vs_sr_drift.png"))
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Figure 4: Bar plot with significance
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _mean_sem(values):
    """Return ``(mean, standard error of the mean)`` over the finite entries of *values*.

    The SEM uses the sample standard deviation (``ddof=1``). With exactly one
    finite value the SEM is 0; with none, both results are NaN.
    """
    values = np.asarray(values, dtype=float)
    values = values[np.isfinite(values)]
    if values.size == 0:
        return float("nan"), float("nan")
    mean = float(values.mean())
    if values.size < 2:
        return mean, 0.0
    return mean, float(values.std(ddof=1) / np.sqrt(values.size))


def _significance_label(p_val):
    """Map a p-value to ``***``/``**``/``*`` (p < 0.001/0.01/0.05), else ``p=<value>``."""
    if p_val < 0.001:
        return "***"
    if p_val < 0.01:
        return "**"
    if p_val < 0.05:
        return "*"
    return f"p={p_val:.3f}"


def fig_barplot(drift_short, anchor_mask, stats_path, outdir):
    """Bar plot of mean short-range drift at loop-anchor vs non-anchor bins.

    Bars show the mean over finite bins with SEM error bars. If *stats_path*
    exists, its ``"perm_p_value"`` entry is drawn as a significance bracket.
    Writes ``barplot_anchor_drift.png`` to *outdir*.

    Parameters
    ----------
    drift_short : array-like
        Per-bin short-range drift values.
    anchor_mask : array-like
        Per-bin anchor indicator, same length as *drift_short*. It is coerced
        to ``bool`` so an integer 0/1 mask selects bins rather than indices.
    stats_path : str
        Path to a JSON object; a missing file just omits the annotation.
    outdir : str
        Output directory.
    """
    drift_short = np.asarray(drift_short, dtype=float)
    anchor = np.asarray(anchor_mask, dtype=bool)

    anchor_mean, anchor_sem = _mean_sem(drift_short[anchor])
    non_mean, non_sem = _mean_sem(drift_short[~anchor])
    means = [anchor_mean, non_mean]
    sems = [anchor_sem, non_sem]

    # Load p-value (optional)
    p_val = None
    if os.path.exists(stats_path):
        with open(stats_path, encoding="utf-8") as f:
            p_val = json.load(f).get("perm_p_value")

    fig, ax = plt.subplots(figsize=(4, 4.5))
    ax.bar(["Loop anchors\n(CTCF bins)", "Non-anchors"], means,
           yerr=sems, capsize=5, color=["#d6604d", "#aaaaaa"],
           edgecolor="black", linewidth=0.8, error_kw={"linewidth": 1.5})
    ax.set_ylabel("Mean short-range cosine drift (< 1 Mb)", fontsize=10)
    ax.set_title("Short-range drift enrichment\nat loop anchors\n"
                 "(loop-scale contacts lost after RAD21 depletion)", fontsize=9)

    # Significance annotation. Only positive, finite bar tops are usable:
    # a NaN limit makes set_ylim raise, and 0 gives a degenerate axis.
    tops = [m + s for m, s in zip(means, sems) if np.isfinite(m + s) and m + s > 0]
    if p_val is not None and tops:
        ymax = max(tops) * 1.15
        ax.set_ylim(0, ymax * 1.2)
        ax.plot([0, 1], [ymax, ymax], color="black", lw=1.2)
        ax.text(0.5, ymax * 1.03, _significance_label(p_val),
                ha="center", va="bottom", fontsize=12)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    _savefig(fig, os.path.join(outdir, "barplot_anchor_drift.png"))
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Main
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def main():
    """Parse CLI arguments, load the perturbation arrays, and render all four figures.

    Exits with an argparse usage error (status 2) when the per-bin arrays
    (``--drift_short``, ``--drift_long``, ``--anchor_mask``) do not all have
    the same length, instead of failing later inside matplotlib or numpy
    indexing.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--control_emb", required=True)
    ap.add_argument("--treated_emb", required=True)
    ap.add_argument("--drift_full", required=True)
    ap.add_argument("--drift_short", required=True)
    ap.add_argument("--drift_long", required=True)
    ap.add_argument("--anchor_mask", required=True)
    ap.add_argument("--stats", required=True, help="drift_stats.json")
    ap.add_argument("--outdir", default="results/perturbation/figures")
    ap.add_argument("--res", type=int, default=25_000,
                    help="Bin size in bp")
    args = ap.parse_args()

    os.makedirs(args.outdir, exist_ok=True)

    ctrl_emb = np.load(args.control_emb)
    trt_emb = np.load(args.treated_emb)
    # Loaded only so a missing or unreadable file fails early; no figure uses it.
    np.load(args.drift_full)
    drift_short = np.load(args.drift_short)
    drift_long = np.load(args.drift_long)
    # Coerce to bool: an integer 0/1 mask would be treated as fancy indices by
    # `arr[mask]`, and `~mask` would not be its logical negation.
    anchor_mask = np.load(args.anchor_mask).astype(bool)

    n_bins = len(drift_long)
    if len(drift_short) != n_bins or len(anchor_mask) != n_bins:
        ap.error(f"per-bin inputs differ in length: --drift_short={len(drift_short)}, "
                 f"--drift_long={n_bins}, --anchor_mask={len(anchor_mask)}")

    print(f"Embeddings: {ctrl_emb.shape} Bins: {len(drift_long)}")
    print(f"Anchor bins: {anchor_mask.sum()} / {len(anchor_mask)}")

    fig_umap(ctrl_emb, trt_emb, args.outdir)
    fig_drift_track(drift_long, drift_short, anchor_mask, args.res, args.outdir)
    fig_scatter(drift_long, drift_short, anchor_mask, args.outdir)
    fig_barplot(drift_short, anchor_mask, args.stats, args.outdir)

    print("\nAll figures saved.")
|
||
|
||
|
||
if __name__ == "__main__":
    # Run as a script: parse the CLI arguments and render every figure.
    main()
|