#!/usr/bin/env python3
"""
Step 8: Perturbation experiment visualisations.

Generates four figures:
  1. UMAP of control vs treated embeddings (coloured by condition)
  2. Genome browser-style short- and long-range drift tracks along chr1,
     with a loop-anchor track underneath
  3. Scatter: long-range drift (x) vs short-range drift (y), loop anchors highlighted
  4. Bar plot: mean short-range drift at loop anchors vs non-anchors

Usage
-----
    python scripts/perturbation_viz.py \
        --control_emb  results/perturbation/control_emb.npy \
        --treated_emb  results/perturbation/treated_emb.npy \
        --drift_full   results/perturbation/drift_full.npy \
        --drift_short  results/perturbation/drift_short.npy \
        --drift_long   results/perturbation/drift_long.npy \
        --anchor_mask  results/perturbation/anchor_mask.npy \
        --stats        results/perturbation/drift_stats.json \
        --outdir       results/perturbation/figures \
        --res          25000
"""

import argparse
import json
import os

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import umap


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _savefig(fig, path, dpi=150):
    """Write *fig* to *path*, close it, and report the location on stdout.

    The parent directory is created when *path* has one. A bare filename
    has an empty dirname, and ``os.makedirs("")`` raises, so that case is
    skipped.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    fig.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved → {path}")


def _finite_max(values, default=0.0):
    """Return the largest finite entry of *values*, or *default* if none."""
    arr = np.asarray(values, dtype=float)
    finite = arr[np.isfinite(arr)]
    return float(finite.max()) if finite.size else default


def _padded_ymax(values, floor=0.01, pad=1.1):
    """Return an upper axis limit for *values*: max(peak, floor) * pad."""
    return max(_finite_max(values, default=floor), floor) * pad


def _drift_panel(ax, positions, values, color, ylabel):
    """Draw one filled drift track and scale its y-axis from the same data."""
    ax.fill_between(positions, values, alpha=0.7, color=color, linewidth=0)
    ax.set_ylabel(ylabel, fontsize=9)
    ax.set_ylim(0, _padded_ymax(values))


# ---------------------------------------------------------------------------
# Figure 1: UMAP control vs treated
# ---------------------------------------------------------------------------

def fig_umap(ctrl_emb, trt_emb, outdir):
    """Plot a joint 2-D UMAP of control and treated embeddings.

    Both embedding matrices are stacked row-wise and projected with a single
    UMAP fit (cosine metric, fixed ``random_state``), so the two conditions
    share one coordinate system. Writes ``umap_perturbation.png`` plus the
    per-condition coordinates (``umap_ctrl.npy``, ``umap_trt.npy``) to
    *outdir*.
    """
    print("Fitting UMAP ...")
    combined = np.vstack([ctrl_emb, trt_emb])
    n = ctrl_emb.shape[0]

    reducer = umap.UMAP(n_components=2, n_neighbors=30, min_dist=0.1,
                        metric="cosine", random_state=42)
    xy = reducer.fit_transform(combined)

    # The first n rows of the stacked matrix are the control embeddings.
    ctrl_xy = xy[:n]
    trt_xy = xy[n:]

    fig, ax = plt.subplots(figsize=(6, 5))
    ax.scatter(ctrl_xy[:, 0], ctrl_xy[:, 1], s=4, alpha=0.6,
               color="#2166ac", label="Control (untreated)")
    ax.scatter(trt_xy[:, 0], trt_xy[:, 1], s=4, alpha=0.6,
               color="#d6604d", label="Treated (6 h auxin)")
    ax.set_xlabel("UMAP 1", fontsize=11)
    ax.set_ylabel("UMAP 2", fontsize=11)
    ax.set_title("VGAE embeddings: control vs RAD21-depleted\n"
                 "(HCT-116 G1/S, chr1, 25 kb)", fontsize=10)
    ax.legend(fontsize=9, markerscale=3)
    _savefig(fig, os.path.join(outdir, "umap_perturbation.png"))

    # Save coordinates for downstream use
    np.save(os.path.join(outdir, "umap_ctrl.npy"), ctrl_xy)
    np.save(os.path.join(outdir, "umap_trt.npy"), trt_xy)


# ---------------------------------------------------------------------------
# Figure 2: Genome-browser drift track
# ---------------------------------------------------------------------------

def fig_drift_track(drift_long, drift_short, anchor_mask, res_bp, outdir):
    """Plot short-range drift, long-range drift and anchor tracks along chr1.

    Three stacked panels share the x-axis (bin index * *res_bp*, in Mb):
    short-range drift on top, long-range drift in the middle, and the
    loop-anchor mask at the bottom. Each drift panel's y-limit is derived
    from the array drawn in that panel. Writes ``drift_track_chr1.png`` to
    *outdir*.
    """
    n = len(drift_long)
    positions = np.arange(n) * res_bp / 1e6  # Mbp
    anchor_mask = np.asarray(anchor_mask, dtype=bool)

    fig, axes = plt.subplots(3, 1, figsize=(14, 7), sharex=True,
                             gridspec_kw={"height_ratios": [3, 3, 1]})

    # Short-range drift
    _drift_panel(axes[0], positions, drift_short, "#d6604d",
                 "Short-range drift\n(< 1 Mb; signal)")
    axes[0].set_title(
        "Embedding drift along chr1 — HCT-116 RAD21 depletion (6 h auxin)\n"
        "Short-range = loop scale (expected HIGH at anchors); "
        "Long-range = sub-compartment (expected LOW)", fontsize=9)

    # Long-range drift
    _drift_panel(axes[1], positions, drift_long, "#4393c3",
                 "Long-range drift\n(2–5 Mb; background)")

    # Anchor track
    ax = axes[2]
    ax.fill_between(positions, anchor_mask.astype(float),
                    color="#1b7837", alpha=0.8, linewidth=0)
    ax.set_ylabel("Loop\nanchor", fontsize=9)
    ax.set_ylim(0, 1.2)
    ax.set_yticks([])
    ax.set_xlabel("chr1 position (Mb)", fontsize=10)

    plt.tight_layout()
    _savefig(fig, os.path.join(outdir, "drift_track_chr1.png"))


# ---------------------------------------------------------------------------
# Figure 3: Long-range vs short-range scatter
# ---------------------------------------------------------------------------

def fig_scatter(drift_long, drift_short, anchor_mask, outdir):
    """Scatter long-range (x) against short-range (y) drift per bin.

    Non-anchor bins are drawn in grey and anchor bins in red on top, with a
    dashed y = x guide on square axes. Writes ``scatter_lr_vs_sr_drift.png``
    to *outdir*.
    """
    # Coerce to bool so that `~` is a logical NOT even for a 0/1 integer mask
    # (on integers `~` is bitwise and would turn into fancy indexing).
    anchor_mask = np.asarray(anchor_mask, dtype=bool)
    non_anchor = ~anchor_mask

    fig, ax = plt.subplots(figsize=(5.5, 5))
    ax.scatter(drift_long[non_anchor], drift_short[non_anchor],
               s=3, alpha=0.3, color="#aaaaaa", label="Non-anchor",
               rasterized=True)
    ax.scatter(drift_long[anchor_mask], drift_short[anchor_mask],
               s=10, alpha=0.8, color="#d6604d",
               label=f"Loop anchor (n={anchor_mask.sum()})", rasterized=True)

    ax.set_xlabel("Long-range cosine drift (2–5 Mb; sub-compartment scale)",
                  fontsize=9)
    ax.set_ylabel("Short-range cosine drift (< 1 Mb; loop scale)", fontsize=9)
    ax.set_title("Loop-scale vs sub-compartment-scale drift\n"
                 "HCT-116 RAD21 depletion — CTCF anchors expected upper-left",
                 fontsize=9)
    ax.legend(fontsize=9, markerscale=2)

    # Diagonal guide
    lim = max(_finite_max(drift_long), _finite_max(drift_short)) * 1.05
    if lim <= 0:
        # Empty or all-zero drift: fall back to unit axes rather than
        # handing matplotlib a zero-width range.
        lim = 1.0
    ax.set_xlim(0, lim)
    ax.set_ylim(0, lim)
    ax.plot([0, lim], [0, lim], "k--", lw=0.8, alpha=0.3)

    _savefig(fig, os.path.join(outdir, "scatter_lr_vs_sr_drift.png"))


# ---------------------------------------------------------------------------
# Figure 4: Bar plot with significance
# ---------------------------------------------------------------------------

def _mean_sem(values):
    """Return (mean, std / sqrt(n)) of *values*; (nan, nan) when empty.

    Uses NumPy's default ``std`` (ddof=0).
    """
    if len(values) == 0:
        return float("nan"), float("nan")
    return values.mean(), values.std() / np.sqrt(len(values))


def fig_barplot(drift_short, anchor_mask, stats_path, outdir):
    """Bar plot of mean short-range drift at anchor vs non-anchor bins.

    Error bars are the standard error of the mean. If *stats_path* exists
    and its JSON object has a ``perm_p_value`` entry, a significance bracket
    is drawn above the bars; otherwise the plot is produced without one.
    Writes ``barplot_anchor_drift.png`` to *outdir*.
    """
    anchor_mask = np.asarray(anchor_mask, dtype=bool)
    groups = (drift_short[anchor_mask], drift_short[~anchor_mask])
    means, sems = (list(col) for col in zip(*(_mean_sem(g) for g in groups)))

    # Load p-value (a missing stats file just means no annotation)
    p_val = None
    try:
        with open(stats_path, encoding="utf-8") as f:
            p_val = json.load(f).get("perm_p_value")
    except FileNotFoundError:
        pass

    fig, ax = plt.subplots(figsize=(4, 4.5))
    colors = ["#d6604d", "#aaaaaa"]
    ax.bar(["Loop anchors\n(CTCF bins)", "Non-anchors"],
           means, yerr=sems, capsize=5,
           color=colors, edgecolor="black", linewidth=0.8,
           error_kw={"linewidth": 1.5})
    ax.set_ylabel("Mean short-range cosine drift (< 1 Mb)", fontsize=10)
    ax.set_title("Short-range drift enrichment\nat loop anchors\n"
                 "(loop-scale contacts lost after RAD21 depletion)",
                 fontsize=9)

    # Significance annotation
    if p_val is not None:
        ymax = max(m + s for m, s in zip(means, sems)) * 1.15
        # An empty group yields NaN; skip the bracket rather than passing
        # NaN to set_ylim.
        if np.isfinite(ymax):
            ax.set_ylim(0, ymax * 1.2)
            ax.plot([0, 1], [ymax, ymax], color="black", lw=1.2)
            sig = ("***" if p_val < 0.001 else
                   "**" if p_val < 0.01 else
                   "*" if p_val < 0.05 else f"p={p_val:.3f}")
            ax.text(0.5, ymax * 1.03, sig, ha="center", va="bottom",
                    fontsize=12)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    _savefig(fig, os.path.join(outdir, "barplot_anchor_drift.png"))


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Parse CLI arguments, load the input arrays and render all figures."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--control_emb", required=True)
    ap.add_argument("--treated_emb", required=True)
    ap.add_argument("--drift_full", required=True)
    ap.add_argument("--drift_short", required=True)
    ap.add_argument("--drift_long", required=True)
    ap.add_argument("--anchor_mask", required=True)
    ap.add_argument("--stats", required=True, help="drift_stats.json")
    ap.add_argument("--outdir", default="results/perturbation/figures")
    ap.add_argument("--res", type=int, default=25_000, help="Bin size in bp")
    args = ap.parse_args()

    os.makedirs(args.outdir, exist_ok=True)

    ctrl_emb = np.load(args.control_emb)
    trt_emb = np.load(args.treated_emb)
    # Loaded so a missing/corrupt file still fails here; no figure below
    # currently plots the full-range drift.
    drift_full = np.load(args.drift_full)
    drift_short = np.load(args.drift_short)
    drift_long = np.load(args.drift_long)
    # Force a boolean mask so `~mask` and mask indexing behave logically
    # even if the array was saved as 0/1 integers.
    anchor_mask = np.load(args.anchor_mask).astype(bool)

    # The figures index the drift arrays with the mask and plot them against
    # one shared position axis, so all three must have one entry per bin.
    n_bins = len(drift_long)
    if len(drift_short) != n_bins or len(anchor_mask) != n_bins:
        ap.error(
            "per-bin arrays differ in length: "
            f"drift_long={n_bins}, drift_short={len(drift_short)}, "
            f"anchor_mask={len(anchor_mask)}")

    print(f"Embeddings: {ctrl_emb.shape} Bins: {len(drift_long)}")
    print(f"Anchor bins: {anchor_mask.sum()} / {len(anchor_mask)}")

    fig_umap(ctrl_emb, trt_emb, args.outdir)
    fig_drift_track(drift_long, drift_short, anchor_mask, args.res, args.outdir)
    fig_scatter(drift_long, drift_short, anchor_mask, args.outdir)
    fig_barplot(drift_short, anchor_mask, args.stats, args.outdir)

    print("\nAll figures saved.")


if __name__ == "__main__":
    main()