#!/usr/bin/env python3
"""
Perturbation drift analysis: Steps 5–7.

Encodes control and treated graphs with a trained VGAE, then decomposes
per-bin embedding drift into short-range and long-range components by
re-encoding with distance-filtered edge sets.  Validates whether short-range
drift is enriched at loop anchor bins using a permutation test.

Biology of RAD21/cohesin depletion (Rao et al. 2017):
  - Loops are lost: contacts at loop scale (~100 kb – 1 Mb)
  - Compartments are preserved: contacts at chromosome scale (>10 Mb)

Therefore the expected drift pattern is:
  - SHORT-RANGE drift HIGH at former loop anchors (loop contacts gone)
  - LONG-RANGE drift LOW genome-wide (compartment contacts persist)

NOTE: the graph max_dist filter (default 5 Mb) means the "long-range" band
(2–5 Mb) is sub-compartment scale, not true compartment scale.  We cannot
directly observe compartment preservation with this graph setup.

Inputs
------
  control_graph.pt   PyG Data object (control condition)
  treated_graph.pt   PyG Data object (treated/perturbed condition)
  model.pt           Trained VGAE state_dict (from train_vgae.py)
  loops.bedpe        Loop anchor coordinates (bedpe, from call_loops.py)

Outputs (under --outdir)
------------------------
  control_emb.npy        Full embeddings, control
  treated_emb.npy        Full embeddings, treated
  drift_full.npy         Per-bin cosine distance, full graph
  drift_short.npy        Per-bin cosine distance, edges < short_cutoff
  drift_long.npy         Per-bin cosine distance, edges >= long_cutoff
  anchor_mask.npy        Bool array: True = loop anchor bin
  anchor_mode.npy        One-element string array: "bedpe" or "ctcf"
  null_distribution.npy  Permutation null statistics
  drift_stats.json       Summary statistics and permutation p-value
"""
import argparse
import json
import os
import sys

import numpy as np
import torch
from torch_geometric.nn.models import VGAE
from torch_geometric.utils import remove_self_loops, to_undirected

from chromatin_gnn.model import build_encoder


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _load_metrics(model_path: str) -> dict:
    """Return the parsed ``metrics.json`` next to *model_path*, or ``{}``.

    main() reads the encoder architecture (encoder, hidden, latent, heads)
    from it so the checkpoint can be loaded into a matching model.
    """
    p = os.path.join(os.path.dirname(os.path.abspath(model_path)), "metrics.json")
    if os.path.exists(p):
        with open(p) as f:
            return json.load(f)
    return {}


def _encode(model, x, edge_index, edge_weight) -> np.ndarray:
    """Run ``model.encode`` in eval mode without gradients; return a NumPy array."""
    model.eval()
    with torch.no_grad():
        z = model.encode(x, edge_index, edge_weight)
    return z.cpu().numpy()


def _cosine_dist(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Row-wise cosine distance (1 - cosine similarity) between two [N, D] arrays.

    A small epsilon keeps all-zero rows from dividing by zero; similarity is
    clipped to [-1, 1] to absorb rounding error.
    """
    norm_a = np.linalg.norm(a, axis=1, keepdims=True) + 1e-8
    norm_b = np.linalg.norm(b, axis=1, keepdims=True) + 1e-8
    sim = (a / norm_a * (b / norm_b)).sum(axis=1)
    return 1.0 - np.clip(sim, -1.0, 1.0)


def _filter_edges_by_distance(edge_index: torch.Tensor,
                              edge_weight: torch.Tensor | None,
                              min_bins: int = 0,
                              max_bins: int = int(1e9)):
    """Keep only edges whose bin-distance |src - dst| lies in [min_bins, max_bins).

    Returns the filtered ``(edge_index, edge_weight)``; ``edge_weight`` stays
    ``None`` when none was given.
    """
    src, dst = edge_index[0], edge_index[1]
    dist = (src - dst).abs()
    mask = (dist >= min_bins) & (dist < max_bins)
    ei = edge_index[:, mask]
    ew = edge_weight[mask] if edge_weight is not None else None
    return ei, ew


def _anchor_mask_from_bedpe(bedpe_path: str, chrom: str, n_bins: int,
                            bin_size: int) -> np.ndarray:
    """Return a bool array of length n_bins; True where a bin overlaps a loop anchor.

    Both anchors of every loop are used (chrom1/start1/end1 and
    chrom2/start2/end2); only anchors on *chrom* are kept.  The file is read
    with a header row.  BEDPE intervals are half-open [start, end), so an
    anchor ending exactly on a bin boundary does not mark the next bin.
    Parsing is best-effort: on failure a warning is printed and the anchors
    marked so far are returned.
    """
    mask = np.zeros(n_bins, dtype=bool)
    try:
        import pandas as pd
        df = pd.read_csv(bedpe_path, sep="\t")
        for c, s, e in (("chrom1", "start1", "end1"),
                        ("chrom2", "start2", "end2")):
            if c not in df.columns:
                continue
            sub = df[df[c] == chrom]
            starts = sub[s].to_numpy(dtype=np.int64)
            ends = sub[e].to_numpy(dtype=np.int64)
            first = starts // bin_size
            # ceil(end / bin_size) is the exclusive last bin of [start, end);
            # always cover at least the bin containing `start`.
            stop = np.maximum(-(-ends // bin_size), first + 1)
            lo = np.clip(first, 0, n_bins)
            hi = np.clip(stop, 0, n_bins)
            for a, b in zip(lo, hi):
                mask[a:b] = True
    except Exception as ex:
        print(f"WARNING: could not parse {bedpe_path}: {ex}")
    return mask


def _anchor_mask_from_ctcf(x: np.ndarray, ctcf_feat_idx: int = 0,
                           percentile: float = 75.0) -> np.ndarray:
    """
    Define loop anchor bins by CTCF ChIP-seq signal.

    CTCF marks the genomic anchor points of RAD21/cohesin-dependent loops.
    Bins whose CTCF signal (column `ctcf_feat_idx` of `x`) is at or above the
    `percentile`-th percentile of the *positive* signal values are treated as
    potential loop anchors.  If no bin has positive signal, no bin is labelled
    an anchor.

    This proxy is used when Hi-C loop calls lack statistical power (e.g. due
    to low sequencing depth), and is independent of the contact data under
    test.

    Reference: Rao et al. 2017 — "CTCF is required for maintaining the
    architecture of cohesin-mediated chromatin loops."
    """
    ctcf = np.asarray(x)[:, ctcf_feat_idx]
    positive = ctcf > 0
    if not positive.any():
        # A 0.0 fallback threshold would label every non-negative bin as an
        # anchor, leaving the permutation test without non-anchor bins.
        return np.zeros(ctcf.shape[0], dtype=bool)
    threshold = np.percentile(ctcf[positive], percentile)
    return (ctcf >= threshold).astype(bool)


def _permutation_pvalue(drift_lr: np.ndarray, anchor_mask: np.ndarray,
                        n_perm: int = 1000, rng_seed: int = 42) -> tuple:
    """
    One-sided label-permutation test for drift enrichment at anchor bins.

    Works on any per-bin drift vector (main() passes the short-range drift;
    the parameter name ``drift_lr`` is kept for backward compatibility).

    Null: drift at a random set of bins (same size as the anchor set) is as
    high as drift at the anchors.
    Observed statistic: mean(drift[anchor]) - mean(drift[~anchor]).
    p-value: (1 + #{null >= observed}) / (1 + n_perm), the add-one Monte-Carlo
    estimator, which is never exactly zero.

    Returns ``(mean_at_anchors, mean_at_non_anchors, p_value, null_stats)``.
    With no anchors everything is NaN and ``null_stats`` is empty; if every bin
    is an anchor (no non-anchor group) or ``n_perm < 1``, the p-value is NaN.
    """
    drift = np.asarray(drift_lr, dtype=float)
    mask = np.asarray(anchor_mask, dtype=bool)
    n_total = drift.shape[0]
    n_anchors = int(mask.sum())
    if n_anchors == 0:
        return np.nan, np.nan, np.nan, []

    obs_anchor = float(drift[mask].mean())
    if n_anchors == n_total:
        # No non-anchor bins: the statistic is undefined, so no test is run.
        return obs_anchor, np.nan, np.nan, []
    obs_non = float(drift[~mask].mean())
    if n_perm < 1:
        return obs_anchor, obs_non, np.nan, []
    obs_stat = obs_anchor - obs_non

    rng = np.random.default_rng(rng_seed)
    total = drift.sum()
    n_non = n_total - n_anchors
    null_stats = []
    for _ in range(n_perm):
        perm_idx = rng.choice(n_total, size=n_anchors, replace=False)
        perm_sum = drift[perm_idx].sum()
        # Non-anchor mean from the complement sum avoids copying the array.
        null_stats.append(perm_sum / n_anchors - (total - perm_sum) / n_non)

    null_arr = np.asarray(null_stats, dtype=float)
    p_val = float((1 + np.count_nonzero(null_arr >= obs_stat)) / (1 + n_perm))
    return obs_anchor, obs_non, p_val, null_stats


def _prepare_graph(data, name: str):
    """Remove self-loops and symmetrise edges in place, keeping weights aligned.

    Weights (if any) are filtered together with the edges; dropping self-loops
    from ``edge_index`` alone would leave ``edge_weight`` misaligned.
    """
    ew = getattr(data, "edge_weight", None)
    ei, ew = remove_self_loops(data.edge_index, ew)
    if ew is not None:
        ei, ew = to_undirected(ei, ew, num_nodes=data.num_nodes)
    else:
        ei = to_undirected(ei, num_nodes=data.num_nodes)
    data.edge_index = ei
    data.edge_weight = ew
    print(f"Graph {name}: {data.num_nodes} nodes, {ei.shape[1]} edges, "
          f"{data.x.shape[1]} node features")
    return data


def _build_model(model_path: str, in_dim: int, device: torch.device):
    """Rebuild the VGAE from metrics.json hyper-parameters and load its weights."""
    # NOTE(review): weights_only=False can execute arbitrary pickled code;
    # only load checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location=device, weights_only=False)
    metrics = _load_metrics(model_path)
    encoder_name = metrics.get("encoder", "deep_gcn")
    hidden = metrics.get("hidden", 128)
    latent = metrics.get("latent", 64)
    heads = metrics.get("heads") or 4
    enc = build_encoder(encoder_name, in_dim=in_dim, hidden=hidden,
                        latent=latent, dropout=0.0, heads=heads)
    model = VGAE(enc).to(device)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Loaded model: {encoder_name}, in={in_dim}, hidden={hidden}, latent={latent}")
    return model


def _band_drift(model, x_ctrl, ctrl_data, x_trt, trt_data,
                min_bins: int, max_bins: int, label: str) -> np.ndarray:
    """Re-encode both graphs using only edges in [min_bins, max_bins); return drift."""
    ctrl_ei, ctrl_ew = _filter_edges_by_distance(
        ctrl_data.edge_index, ctrl_data.edge_weight,
        min_bins=min_bins, max_bins=max_bins)
    trt_ei, trt_ew = _filter_edges_by_distance(
        trt_data.edge_index, trt_data.edge_weight,
        min_bins=min_bins, max_bins=max_bins)
    print(f"  Control {label} edges: {ctrl_ei.shape[1]}")
    print(f"  Treated {label} edges: {trt_ei.shape[1]}")
    ctrl_emb = _encode(model, x_ctrl, ctrl_ei, ctrl_ew)
    trt_emb = _encode(model, x_trt, trt_ei, trt_ew)
    return _cosine_dist(ctrl_emb, trt_emb)


def _choose_anchor_mask(args, ctrl_x: np.ndarray, n_bins: int):
    """Label anchor bins from the bedpe if it has enough loops, else by CTCF proxy.

    Returns ``(anchor_mask, anchor_mode, n_bedpe_loops)`` with anchor_mode
    either "bedpe" or "ctcf".
    """
    n_bedpe_loops = 0
    if args.loops and os.path.exists(args.loops):
        import pandas as pd
        try:
            bedpe_df = pd.read_csv(args.loops, sep="\t")
            n_bedpe_loops = int((bedpe_df["chrom1"] == args.chrom).sum())
        except Exception as ex:
            # Best effort: fall back to CTCF mode, but say why.
            print(f"  WARNING: could not read {args.loops}: {ex}")
            n_bedpe_loops = 0

    # Require at least one loop so --min_loops 0 cannot select an empty bedpe.
    if n_bedpe_loops >= max(1, args.min_loops):
        anchor_mask = _anchor_mask_from_bedpe(
            args.loops, args.chrom, n_bins, args.res)
        print(f"  Mode: Hi-C loop anchors (bedpe), {n_bedpe_loops} loops")
        return anchor_mask, "bedpe", n_bedpe_loops

    if n_bedpe_loops > 0:
        print(f"  NOTE: Only {n_bedpe_loops} loops in bedpe (< min_loops={args.min_loops}). "
              f"Likely low sequencing depth in G1/S-synchronized sample. "
              f"Falling back to CTCF ChIP-seq peak anchor proxy.")
    else:
        print("  NOTE: No bedpe loops. Using CTCF ChIP-seq peak proxy for anchors.")

    n_feat = ctrl_x.shape[1]
    if not -n_feat <= args.ctcf_feat_idx < n_feat:
        sys.exit(f"ERROR: --ctcf_feat_idx {args.ctcf_feat_idx} is out of range "
                 f"for {n_feat} node features.")
    anchor_mask = _anchor_mask_from_ctcf(
        ctrl_x, ctcf_feat_idx=args.ctcf_feat_idx, percentile=args.ctcf_percentile)
    print(f"  Mode: CTCF signal proxy (percentile ≥ {args.ctcf_percentile})")
    return anchor_mask, "ctcf", n_bedpe_loops


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Command-line entry point: run Steps 5–7 and write all outputs to --outdir."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--control_graph", required=True)
    ap.add_argument("--treated_graph", required=True)
    ap.add_argument("--model", required=True,
                    help="model.pt saved by train_vgae.py")
    ap.add_argument("--loops", default=None,
                    help="Loop call bedpe (from call_loops.py). "
                         "If omitted or if fewer than --min_loops loops are found, "
                         "falls back to CTCF-signal anchor labelling.")
    ap.add_argument("--min_loops", type=int, default=10,
                    help="Minimum bedpe loops required to use Hi-C anchor mode; "
                         "below this threshold CTCF proxy is used (default 10)")
    ap.add_argument("--ctcf_percentile", type=float, default=75.0,
                    help="CTCF signal percentile threshold for anchor labelling "
                         "when falling back to CTCF proxy mode (default 75)")
    ap.add_argument("--ctcf_feat_idx", type=int, default=0,
                    help="Node-feature column holding the CTCF signal, used in "
                         "CTCF proxy mode (default 0)")
    ap.add_argument("--chrom", default="chr1")
    ap.add_argument("--res", type=int, default=25_000,
                    help="Bin size in bp (must match graph resolution)")
    ap.add_argument("--short_cutoff", type=int, default=1_000_000,
                    help="Short-range upper bound in bp (default 1 Mb; "
                         "captures full loop scale of ~100 kb–1 Mb)")
    ap.add_argument("--long_cutoff", type=int, default=2_000_000,
                    help="Long-range lower bound in bp (default 2 Mb; "
                         "note: true compartment contacts are >10 Mb and "
                         "are beyond the graph max_dist filter)")
    ap.add_argument("--n_perm", type=int, default=1000)
    ap.add_argument("--outdir", default="results/perturbation")
    ap.add_argument("--device", default="auto",
                    help="cuda | cpu | auto (default: auto-detect CUDA)")
    args = ap.parse_args()
    if args.res <= 0:
        ap.error("--res must be a positive bin size in bp")

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}"
          + (f" ({torch.cuda.get_device_name(device)})" if device.type == "cuda" else ""))

    os.makedirs(args.outdir, exist_ok=True)

    short_bins = args.short_cutoff // args.res   # short-range: |Δbin| <  short_bins
    long_bins = args.long_cutoff // args.res     # long-range:  |Δbin| >= long_bins
    print(f"Short-range: edges < {args.short_cutoff//1000} kb "
          f"(|Δbin| < {short_bins})")
    print(f"Long-range: edges ≥ {args.long_cutoff//1000} kb "
          f"(|Δbin| ≥ {long_bins})")

    # ---- Load graphs ----
    # NOTE(review): weights_only=False is needed to unpickle PyG Data objects
    # and can execute arbitrary code; only load trusted graph files.
    ctrl_data = _prepare_graph(
        torch.load(args.control_graph, weights_only=False), "control")
    trt_data = _prepare_graph(
        torch.load(args.treated_graph, weights_only=False), "treated")

    n_bins = ctrl_data.num_nodes
    in_dim = ctrl_data.x.shape[1]
    # Per-bin drift compares row i of both embeddings, so binning and feature
    # layout must match exactly.
    if trt_data.num_nodes != n_bins or trt_data.x.shape[1] != in_dim:
        sys.exit(f"ERROR: graph shapes differ (control: {n_bins} nodes x {in_dim} "
                 f"features, treated: {trt_data.num_nodes} nodes x "
                 f"{trt_data.x.shape[1]} features).")

    # ---- Load model ----
    model = _build_model(args.model, in_dim, device)

    # Move graph tensors to device (ctrl_data.x itself stays on the CPU and is
    # reused below for CTCF labelling).
    x_ctrl = ctrl_data.x.float().to(device)
    x_trt = trt_data.x.float().to(device)
    for d in (ctrl_data, trt_data):
        d.edge_index = d.edge_index.to(device)
        if d.edge_weight is not None:
            d.edge_weight = d.edge_weight.to(device)

    # ---- Step 5: Full embeddings ----
    print("\n[Step 5] Encoding full graphs ...")
    ctrl_emb = _encode(model, x_ctrl, ctrl_data.edge_index, ctrl_data.edge_weight)
    trt_emb = _encode(model, x_trt, trt_data.edge_index, trt_data.edge_weight)
    np.save(os.path.join(args.outdir, "control_emb.npy"), ctrl_emb)
    np.save(os.path.join(args.outdir, "treated_emb.npy"), trt_emb)

    drift_full = _cosine_dist(ctrl_emb, trt_emb)
    np.save(os.path.join(args.outdir, "drift_full.npy"), drift_full)
    print(f"Full drift: mean={drift_full.mean():.4f}, "
          f"median={np.median(drift_full):.4f}, max={drift_full.max():.4f}")

    # ---- Step 6: Short-range and long-range drift ----
    print("\n[Step 6] Short-range drift ...")
    drift_short = _band_drift(model, x_ctrl, ctrl_data, x_trt, trt_data,
                              min_bins=0, max_bins=short_bins,
                              label="short-range")
    np.save(os.path.join(args.outdir, "drift_short.npy"), drift_short)
    print(f"  Short-range drift: mean={drift_short.mean():.4f}, "
          f"max={drift_short.max():.4f}")

    print("\n[Step 6] Long-range drift ...")
    drift_long = _band_drift(model, x_ctrl, ctrl_data, x_trt, trt_data,
                             min_bins=long_bins, max_bins=int(1e9),
                             label="long-range")
    np.save(os.path.join(args.outdir, "drift_long.npy"), drift_long)
    print(f"  Long-range drift: mean={drift_long.mean():.4f}, "
          f"max={drift_long.max():.4f}")

    # ---- Loop anchor labeling ----
    print("\n[Step 6] Defining loop anchor bins ...")
    anchor_mask, anchor_mode, n_bedpe_loops = _choose_anchor_mask(
        args, ctrl_data.x.detach().cpu().numpy(), n_bins)

    np.save(os.path.join(args.outdir, "anchor_mask.npy"), anchor_mask)
    np.save(os.path.join(args.outdir, "anchor_mode.npy"), np.array([anchor_mode]))
    n_anchors = int(anchor_mask.sum())
    print(f"  Loop anchor bins: {n_anchors} / {n_bins} "
          f"({100*n_anchors/max(1, n_bins):.1f}%)")

    # Rank bins by SHORT-range drift (the signal of interest: loop-scale
    # contacts are lost after RAD21 depletion).
    rank_order = np.argsort(drift_short)[::-1]
    top10_pct = rank_order[:max(1, n_bins // 10)]
    overlap = int(anchor_mask[top10_pct].sum())
    print(f"  Anchor bins in top-10% SHORT-range drift: {overlap} "
          f"/ {n_anchors} ({100*overlap/max(1, n_anchors):.1f}%)")

    # Also report long-range drift summary (expected to be low: compartments stable)
    print(f"  Long-range drift mean (expected low, compartments preserved): "
          f"{drift_long.mean():.4f}")

    # ---- Step 7: Permutation test on SHORT-range drift at loop anchors ----
    print(f"\n[Step 7] Permutation test ({args.n_perm} shuffles) ...")
    print("  Testing: is SHORT-range drift enriched at CTCF/loop-anchor bins?")
    obs_anchor, obs_non, p_val, null_stats = _permutation_pvalue(
        drift_short, anchor_mask, n_perm=args.n_perm)

    print(f"  Mean short-range drift at anchors: {obs_anchor:.4f}")
    print(f"  Mean short-range drift at non-anchors: {obs_non:.4f}")
    print(f"  Empirical p-value: {p_val:.4f}")
    if np.isnan(p_val):
        print("  RESULT: Permutation test not computed (needs both anchor and "
              "non-anchor bins and n_perm ≥ 1).")
    elif p_val < 0.05:
        print("  RESULT: Short-range drift is significantly enriched at loop anchors "
              "(p < 0.05) — consistent with loop-scale contact loss at CTCF sites.")
    elif p_val < 0.1:
        print("  RESULT: Trend toward short-range enrichment at loop anchors "
              "(0.05 ≤ p < 0.10) — marginal evidence.")
    else:
        print("  RESULT: Short-range drift is NOT significantly enriched at loop anchors "
              f"(p = {p_val:.3f}) — null result. Report honestly.")

    # ---- Save stats ----
    stats = {
        "chrom": args.chrom,
        "bin_size_bp": args.res,
        "short_cutoff_bp": args.short_cutoff,
        "long_cutoff_bp": args.long_cutoff,
        "n_bins": int(n_bins),
        "anchor_mode": anchor_mode,
        "n_bedpe_loops": int(n_bedpe_loops),
        "ctcf_percentile": args.ctcf_percentile if anchor_mode == "ctcf" else None,
        "ctcf_feat_idx": args.ctcf_feat_idx if anchor_mode == "ctcf" else None,
        "n_anchor_bins": int(n_anchors),
        "drift_full_mean": float(drift_full.mean()),
        "drift_short_mean": float(drift_short.mean()),
        "drift_long_mean": float(drift_long.mean()),
        "signal_drift": "short_range",
        "obs_anchor_sr_mean": float(obs_anchor) if not np.isnan(obs_anchor) else None,
        "obs_non_sr_mean": float(obs_non) if not np.isnan(obs_non) else None,
        "drift_long_mean_note": "expected low (compartments preserved); "
                                "caveat: graph max_dist=5Mb does not reach "
                                "true compartment scale (>10Mb)",
        "perm_p_value": float(p_val) if not np.isnan(p_val) else None,
        "perm_p_value_method": "(1 + #null>=obs) / (1 + n_perm)",
        "n_perm": args.n_perm,
        "top10pct_anchor_overlap_short_range": int(overlap),
    }
    with open(os.path.join(args.outdir, "drift_stats.json"), "w") as f:
        json.dump(stats, f, indent=2)

    np.save(os.path.join(args.outdir, "null_distribution.npy"),
            np.array(null_stats, dtype=float))

    print(f"\nSaved → {args.outdir}/")


if __name__ == "__main__":
    main()