420 lines
18 KiB
Python
420 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
"""
Perturbation drift analysis: Steps 5–7.

Encodes control and treated graphs with a trained VGAE, then decomposes
per-bin embedding drift into short-range and long-range components by
re-encoding with distance-filtered edge sets. Tests whether short-range
drift is enriched at loop anchor bins using a label-permutation test.

Biology of RAD21/cohesin depletion (Rao et al. 2017):
  - Loops are lost: contacts at loop scale (~100 kb – 1 Mb)
  - Compartments are preserved: contacts at chromosome scale (>10 Mb)

Therefore the expected drift pattern is:
  - SHORT-RANGE drift HIGH at former loop anchors (loop contacts gone)
  - LONG-RANGE drift LOW genome-wide (compartment contacts persist)

NOTE: the graph max_dist filter (default 5 Mb) means the "long-range" band
(2–5 Mb) is sub-compartment scale, not true compartment scale. We cannot
directly observe compartment preservation with this graph setup.

Inputs
------
control_graph.pt   PyG Data object (control condition)
treated_graph.pt   PyG Data object (treated/perturbed condition)
model.pt           Trained VGAE state_dict (from train_vgae.py)
loops.bedpe        Loop anchor coordinates (bedpe, from call_loops.py).
                   Optional: if absent or too sparse, anchors are instead
                   labelled from the CTCF signal in node-feature column 0.

Outputs (under --outdir)
------------------------
control_emb.npy        Full embeddings, control
treated_emb.npy        Full embeddings, treated
drift_full.npy         Per-bin cosine distance, full graph
drift_short.npy        Per-bin cosine distance, edges < short_cutoff
drift_long.npy         Per-bin cosine distance, edges >= long_cutoff
anchor_mask.npy        Bool array: True = loop anchor bin
anchor_mode.npy        Anchor labelling mode used ("bedpe" or "ctcf")
null_distribution.npy  Permutation null statistics
drift_stats.json       Summary statistics and permutation p-value
"""
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
import sys
|
||
|
||
import numpy as np
|
||
import torch
|
||
from torch_geometric.nn.models import VGAE
|
||
from torch_geometric.utils import remove_self_loops, to_undirected
|
||
|
||
from chromatin_gnn.model import build_encoder
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _load_metrics(model_path: str) -> dict:
    """Read the ``metrics.json`` stored alongside a saved model.

    Args:
        model_path: Path to the model file; only its directory is used.

    Returns:
        The parsed JSON content, or an empty dict when the directory holds
        no ``metrics.json``.
    """
    model_dir = os.path.dirname(os.path.abspath(model_path))
    metrics_file = os.path.join(model_dir, "metrics.json")
    if not os.path.exists(metrics_file):
        return {}
    with open(metrics_file) as fh:
        return json.load(fh)
|
||
|
||
|
||
def _encode(model, x, edge_index, edge_weight):
    """Encode one graph with the VGAE and return the latent as a NumPy array.

    The model is put in eval mode and the forward pass runs under
    ``torch.no_grad()``; the result of ``model.encode`` is copied to the
    CPU before conversion.
    """
    model.eval()
    with torch.no_grad():
        latent = model.encode(x, edge_index, edge_weight)
        result = latent.cpu().numpy()
    return result
|
||
|
||
|
||
def _cosine_dist(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Cosine distance between corresponding rows of two ``[N, D]`` arrays.

    Each row is scaled by its L2 norm plus 1e-8 (so all-zero rows do not
    divide by zero), the similarity is clipped to ``[-1, 1]``, and
    ``1 - similarity`` is returned as a length-``N`` array.
    """
    eps = 1e-8
    unit_a = a / (np.linalg.norm(a, axis=1, keepdims=True) + eps)
    unit_b = b / (np.linalg.norm(b, axis=1, keepdims=True) + eps)
    cos_sim = np.sum(unit_a * unit_b, axis=1)
    return 1.0 - np.clip(cos_sim, -1.0, 1.0)
|
||
|
||
|
||
def _filter_edges_by_distance(edge_index: torch.Tensor,
|
||
edge_weight: torch.Tensor | None,
|
||
min_bins: int = 0,
|
||
max_bins: int = int(1e9)):
|
||
"""Keep only edges whose bin-distance falls within [min_bins, max_bins)."""
|
||
src, dst = edge_index[0], edge_index[1]
|
||
dist = (src - dst).abs()
|
||
mask = (dist >= min_bins) & (dist < max_bins)
|
||
ei = edge_index[:, mask]
|
||
ew = edge_weight[mask] if edge_weight is not None else None
|
||
return ei, ew
|
||
|
||
|
||
def _anchor_mask_from_bedpe(bedpe_path: str, chrom: str,
                            n_bins: int, bin_size: int) -> np.ndarray:
    """Return a bool array of length ``n_bins``; True where a bin overlaps a loop anchor.

    The file is read as tab-separated with a header row. Both anchors of every
    loop are used (columns ``chrom1/start1/end1`` and ``chrom2/start2/end2``);
    a side whose columns are missing is skipped, and only rows whose
    chromosome value equals ``chrom`` contribute.

    Coordinates are treated as BEDPE half-open ``[start, end)`` intervals, so
    an anchor covers bins ``start // bin_size`` .. ``ceil(end / bin_size) - 1``
    (at least the start bin). Rows with non-numeric coordinates are ignored.

    Parsing is best-effort: if the file cannot be read, a warning is printed
    and an all-False mask is returned.
    """
    import pandas as pd

    mask = np.zeros(n_bins, dtype=bool)
    try:
        df = pd.read_csv(bedpe_path, sep="\t")
    except (OSError, ValueError) as ex:  # missing file, empty file, parse error
        print(f"WARNING: could not parse {bedpe_path}: {ex}")
        return mask

    for chrom_col, start_col, end_col in (("chrom1", "start1", "end1"),
                                          ("chrom2", "start2", "end2")):
        if not {chrom_col, start_col, end_col}.issubset(df.columns):
            continue
        rows = df.loc[df[chrom_col] == chrom]
        starts = pd.to_numeric(rows[start_col], errors="coerce")
        ends = pd.to_numeric(rows[end_col], errors="coerce")
        valid = starts.notna() & ends.notna()
        if not valid.any():
            continue
        lo = starts[valid].to_numpy(dtype=np.int64) // bin_size
        # Ceiling division: an end lying exactly on a bin boundary is exclusive
        # and must not mark the following bin. Zero-length anchors still mark
        # their start bin.
        hi = np.maximum(-(-ends[valid].to_numpy(dtype=np.int64) // bin_size), lo + 1)
        for a, b in zip(np.clip(lo, 0, n_bins), np.clip(hi, 0, n_bins)):
            mask[a:b] = True
    return mask
|
||
|
||
|
||
def _anchor_mask_from_ctcf(x: np.ndarray, ctcf_feat_idx: int = 0,
                           percentile: float = 75.0) -> np.ndarray:
    """
    Define loop anchor bins by CTCF ChIP-seq signal.

    CTCF marks the genomic anchor points of RAD21/cohesin-dependent loops.
    Bins whose CTCF value is at or above the ``percentile``-th percentile of
    the *positive* CTCF values are treated as potential loop anchors; bins
    with zero signal are therefore never anchors.

    This proxy is used when Hi-C loop calls lack statistical power (e.g. due
    to low sequencing depth), and is independent of the contact data under test.

    Args:
        x: Node-feature matrix ``[n_bins, n_features]``.
        ctcf_feat_idx: Column of ``x`` holding the CTCF signal.
        percentile: Percentile (0–100) of positive signal used as threshold.

    Returns:
        Bool array of length ``n_bins``. If no bin has positive CTCF signal,
        a warning is printed and the mask is all False (previously the
        threshold fell back to 0 and *every* bin became an anchor).

    Reference: Rao et al. 2017, "Cohesin Loss Eliminates All Loop Domains" (Cell).
    """
    ctcf = np.asarray(x)[:, ctcf_feat_idx]
    positive = ctcf[ctcf > 0]
    if positive.size == 0:
        print("WARNING: no positive CTCF signal; no bins labelled as anchors.")
        return np.zeros(ctcf.shape[0], dtype=bool)
    threshold = np.percentile(positive, percentile)
    return np.asarray(ctcf >= threshold, dtype=bool)
|
||
|
||
|
||
def _permutation_pvalue(drift_lr: np.ndarray,
                        anchor_mask: np.ndarray,
                        n_perm: int = 1000,
                        rng_seed: int = 42) -> tuple:
    """
    One-sided label-permutation test for drift enrichment at anchor bins.

    Despite its name, ``drift_lr`` may be any per-bin drift vector (``main``
    passes the short-range drift).

    Observed statistic: mean(drift[anchor]) - mean(drift[~anchor]).
    Null: anchor labels are exchangeable — a random set of the same number
    of bins yields the same statistic.
    p-value: (1 + #{null >= observed}) / (1 + n_perm), which is never exactly
    zero for a finite number of permutations.

    Args:
        drift_lr: Per-bin drift values, length N.
        anchor_mask: Length-N mask (coerced to bool) marking anchor bins.
        n_perm: Number of random relabellings (must be >= 0).
        rng_seed: Seed for ``numpy.random.default_rng``.

    Returns:
        ``(mean_anchor, mean_non_anchor, p_value, null_stats)``. With no anchor
        bins all three numbers are NaN; when every bin is an anchor there is
        no comparison group, so ``mean_non_anchor`` and ``p_value`` are NaN.
        ``null_stats`` is a list of floats (empty when the test is undefined).

    Raises:
        ValueError: if ``n_perm`` is negative.
    """
    if n_perm < 0:
        raise ValueError(f"n_perm must be non-negative, got {n_perm}")

    drift = np.asarray(drift_lr, dtype=float)
    mask = np.asarray(anchor_mask, dtype=bool)
    n_total = drift.size
    n_anchors = int(mask.sum())
    if n_anchors == 0:
        return np.nan, np.nan, np.nan, []

    obs_anchor = float(drift[mask].mean())
    n_non = n_total - n_anchors
    if n_non == 0:
        # Every bin is an anchor: the contrast (and hence the test) is undefined.
        return obs_anchor, np.nan, np.nan, []
    obs_non = float(drift[~mask].mean())
    obs_stat = obs_anchor - obs_non

    rng = np.random.default_rng(rng_seed)
    total = drift.sum()
    null_stats = []
    for _ in range(n_perm):
        perm_idx = rng.choice(n_total, size=n_anchors, replace=False)
        perm_sum = drift[perm_idx].sum()
        # Non-anchor mean via the complement sum avoids copying the array.
        null_stats.append(float(perm_sum / n_anchors - (total - perm_sum) / n_non))

    null_arr = np.asarray(null_stats, dtype=float)
    # Relative tolerance so permutations tied with the observed statistic up to
    # summation round-off still count as ">=" (as in scipy.stats.permutation_test).
    gamma = abs(obs_stat) * 100 * np.finfo(float).eps
    n_extreme = int(np.count_nonzero(null_arr >= obs_stat - gamma))
    p_val = (n_extreme + 1) / (n_perm + 1)
    return obs_anchor, obs_non, float(p_val), null_stats
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Main
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _load_clean_graph(path: str, name: str):
    """Load a PyG ``Data`` graph, drop self-loops, and symmetrise its edges.

    Self-loops are removed from ``edge_index`` and ``edge_weight`` together so
    the two stay aligned. The graph is then made undirected with
    ``to_undirected``, and a one-line summary labelled *name* is printed.
    """
    # weights_only=False is needed to unpickle a PyG Data object; unpickling
    # can execute arbitrary code, so only load graph files you trust.
    data = torch.load(path, weights_only=False)
    edge_weight = getattr(data, "edge_weight", None)
    ei, edge_weight = remove_self_loops(data.edge_index, edge_weight)
    if edge_weight is not None:
        # NOTE(review): to_undirected coalesces duplicates with PyG's default
        # reduce, so weights of edges stored in both directions are combined —
        # presumably mirrors training; confirm against train_vgae.py.
        ei, edge_weight = to_undirected(ei, edge_weight, num_nodes=data.num_nodes)
    else:
        ei = to_undirected(ei, num_nodes=data.num_nodes)
    data.edge_index = ei
    data.edge_weight = edge_weight
    print(f"Graph {name}: {data.num_nodes} nodes, {ei.shape[1]} edges, "
          f"{data.x.shape[1]} node features")
    return data


def _band_drift(model, x_ctrl, x_trt, ctrl_data, trt_data,
                min_bins: int, max_bins: int, label: str) -> np.ndarray:
    """Re-encode both conditions on one distance band and return per-bin drift.

    Only edges with ``min_bins <= |Δbin| < max_bins`` are kept; node features
    are unchanged. Edge counts are printed using *label*.
    """
    ctrl_ei, ctrl_ew = _filter_edges_by_distance(
        ctrl_data.edge_index, ctrl_data.edge_weight,
        min_bins=min_bins, max_bins=max_bins)
    trt_ei, trt_ew = _filter_edges_by_distance(
        trt_data.edge_index, trt_data.edge_weight,
        min_bins=min_bins, max_bins=max_bins)
    print(f"  Control {label} edges: {ctrl_ei.shape[1]}")
    print(f"  Treated {label} edges: {trt_ei.shape[1]}")
    ctrl_emb = _encode(model, x_ctrl, ctrl_ei, ctrl_ew)
    trt_emb = _encode(model, x_trt, trt_ei, trt_ew)
    return _cosine_dist(ctrl_emb, trt_emb)


def _select_anchor_mask(args, n_bins: int, node_features: np.ndarray):
    """Label anchor bins from bedpe loops if enough exist, else from CTCF signal.

    Returns ``(anchor_mask, anchor_mode, n_bedpe_loops)`` where ``anchor_mode``
    is ``"bedpe"`` or ``"ctcf"``. The CTCF proxy reads feature column 0.
    """
    n_bedpe_loops = 0
    bedpe_readable = False
    if args.loops:
        if os.path.exists(args.loops):
            import pandas as pd
            try:
                bedpe_df = pd.read_csv(args.loops, sep="\t")
                n_bedpe_loops = int((bedpe_df["chrom1"] == args.chrom).sum())
                bedpe_readable = True
            except (OSError, ValueError, KeyError) as ex:
                print(f"  WARNING: could not read loop calls from {args.loops}: {ex}")
        else:
            print(f"  WARNING: --loops file not found: {args.loops}")

    if bedpe_readable and n_bedpe_loops >= args.min_loops:
        mask = _anchor_mask_from_bedpe(args.loops, args.chrom, n_bins, args.res)
        print(f"  Mode: Hi-C loop anchors (bedpe), {n_bedpe_loops} loops")
        return mask, "bedpe", n_bedpe_loops

    if n_bedpe_loops > 0:
        print(f"  NOTE: Only {n_bedpe_loops} loops in bedpe (< min_loops={args.min_loops}). "
              "Likely low sequencing depth in G1/S-synchronized sample. "
              "Falling back to CTCF ChIP-seq peak anchor proxy.")
    else:
        print("  NOTE: No bedpe loops. Using CTCF ChIP-seq peak proxy for anchors.")
    mask = _anchor_mask_from_ctcf(node_features, ctcf_feat_idx=0,
                                  percentile=args.ctcf_percentile)
    print(f"  Mode: CTCF signal proxy (percentile ≥ {args.ctcf_percentile})")
    return mask, "ctcf", n_bedpe_loops


def main():
    """Run Steps 5–7: encode both conditions, decompose drift, test anchors.

    Parses command-line arguments and writes the outputs under ``--outdir``:
    - embeddings and drift arrays (``.npy``)
    - the anchor mask and anchor mode
    - the permutation null distribution
    - ``drift_stats.json``

    Exits with an error when the control and treated graphs differ in node
    count or feature width, since per-bin drift then has no meaning.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--control_graph", required=True)
    ap.add_argument("--treated_graph", required=True)
    ap.add_argument("--model", required=True,
                    help="model.pt saved by train_vgae.py")
    ap.add_argument("--loops", default=None,
                    help="Loop call bedpe (from call_loops.py). "
                         "If omitted or if fewer than --min_loops loops are found, "
                         "falls back to CTCF-signal anchor labelling.")
    ap.add_argument("--min_loops", type=int, default=10,
                    help="Minimum bedpe loops required to use Hi-C anchor mode; "
                         "below this threshold CTCF proxy is used (default 10)")
    ap.add_argument("--ctcf_percentile", type=float, default=75.0,
                    help="CTCF signal percentile threshold for anchor labelling "
                         "when falling back to CTCF proxy mode (default 75)")
    ap.add_argument("--chrom", default="chr1")
    ap.add_argument("--res", type=int, default=25_000,
                    help="Bin size in bp (must match graph resolution)")
    ap.add_argument("--short_cutoff", type=int, default=1_000_000,
                    help="Short-range upper bound in bp (default 1 Mb; "
                         "captures full loop scale of ~100 kb–1 Mb)")
    ap.add_argument("--long_cutoff", type=int, default=2_000_000,
                    help="Long-range lower bound in bp (default 2 Mb; "
                         "note: true compartment contacts are >10 Mb and "
                         "are beyond the graph max_dist filter)")
    ap.add_argument("--n_perm", type=int, default=1000)
    ap.add_argument("--outdir", default="results/perturbation")
    ap.add_argument("--device", default="auto",
                    help="cuda | cpu | auto (default: auto-detect CUDA)")
    args = ap.parse_args()
    if args.res <= 0:
        ap.error("--res must be a positive bin size in bp")
    if args.n_perm < 0:
        ap.error("--n_perm must be non-negative")

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    gpu_note = (f" ({torch.cuda.get_device_name(device)})"
                if device.type == "cuda" else "")
    print(f"Using device: {device}{gpu_note}")

    os.makedirs(args.outdir, exist_ok=True)

    short_bins = args.short_cutoff // args.res  # |Δbin| <  short_bins → short-range
    long_bins = args.long_cutoff // args.res    # |Δbin| >= long_bins  → long-range
    print(f"Short-range: edges < {args.short_cutoff//1000} kb "
          f"(|Δbin| < {short_bins})")
    print(f"Long-range: edges ≥ {args.long_cutoff//1000} kb "
          f"(|Δbin| ≥ {long_bins})")

    # ---- Load graphs ----
    ctrl_data = _load_clean_graph(args.control_graph, "control")
    trt_data = _load_clean_graph(args.treated_graph, "treated")
    if (ctrl_data.num_nodes != trt_data.num_nodes
            or ctrl_data.x.shape[1] != trt_data.x.shape[1]):
        sys.exit(f"ERROR: control graph ({ctrl_data.num_nodes} nodes, "
                 f"{ctrl_data.x.shape[1]} features) and treated graph "
                 f"({trt_data.num_nodes} nodes, {trt_data.x.shape[1]} features) "
                 "must share the same binning and feature set.")
    n_bins = ctrl_data.num_nodes

    # ---- Load model ----
    # weights_only=False unpickles arbitrary objects; only load trusted files.
    state_dict = torch.load(args.model, map_location=device, weights_only=False)
    metrics = _load_metrics(args.model)
    encoder_name = metrics.get("encoder", "deep_gcn")
    hidden = metrics.get("hidden", 128)
    latent = metrics.get("latent", 64)
    heads = metrics.get("heads") or 4
    in_dim = ctrl_data.x.shape[1]

    enc = build_encoder(encoder_name, in_dim=in_dim, hidden=hidden,
                        latent=latent, dropout=0.0, heads=heads)
    model = VGAE(enc).to(device)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Loaded model: {encoder_name}, in={in_dim}, hidden={hidden}, latent={latent}")

    # Move graph tensors to device (ctrl_data.x itself stays on the host; it is
    # reused below for CTCF anchor labelling).
    x_ctrl = ctrl_data.x.float().to(device)
    x_trt = trt_data.x.float().to(device)
    for data in (ctrl_data, trt_data):
        data.edge_index = data.edge_index.to(device)
        if data.edge_weight is not None:
            data.edge_weight = data.edge_weight.to(device)

    # ---- Step 5: Full embeddings ----
    print("\n[Step 5] Encoding full graphs ...")
    ctrl_emb = _encode(model, x_ctrl, ctrl_data.edge_index, ctrl_data.edge_weight)
    trt_emb = _encode(model, x_trt, trt_data.edge_index, trt_data.edge_weight)
    np.save(os.path.join(args.outdir, "control_emb.npy"), ctrl_emb)
    np.save(os.path.join(args.outdir, "treated_emb.npy"), trt_emb)

    drift_full = _cosine_dist(ctrl_emb, trt_emb)
    np.save(os.path.join(args.outdir, "drift_full.npy"), drift_full)
    print(f"Full drift: mean={drift_full.mean():.4f}, "
          f"median={np.median(drift_full):.4f}, max={drift_full.max():.4f}")

    # ---- Step 6: Short-range and long-range drift ----
    print("\n[Step 6] Short-range drift ...")
    drift_short = _band_drift(model, x_ctrl, x_trt, ctrl_data, trt_data,
                              min_bins=0, max_bins=short_bins, label="short-range")
    np.save(os.path.join(args.outdir, "drift_short.npy"), drift_short)
    print(f"  Short-range drift: mean={drift_short.mean():.4f}, "
          f"max={drift_short.max():.4f}")

    print("\n[Step 6] Long-range drift ...")
    drift_long = _band_drift(model, x_ctrl, x_trt, ctrl_data, trt_data,
                             min_bins=long_bins, max_bins=int(1e9), label="long-range")
    np.save(os.path.join(args.outdir, "drift_long.npy"), drift_long)
    print(f"  Long-range drift: mean={drift_long.mean():.4f}, "
          f"max={drift_long.max():.4f}")

    # ---- Loop anchor labeling ----
    print("\n[Step 6] Defining loop anchor bins ...")
    anchor_mask, anchor_mode, n_bedpe_loops = _select_anchor_mask(
        args, n_bins, ctrl_data.x.detach().cpu().numpy())

    np.save(os.path.join(args.outdir, "anchor_mask.npy"), anchor_mask)
    np.save(os.path.join(args.outdir, "anchor_mode.npy"),
            np.array([anchor_mode]))
    n_anchors = int(anchor_mask.sum())
    print(f"  Loop anchor bins: {n_anchors} / {n_bins} "
          f"({100*n_anchors/n_bins:.1f}%)")

    # Rank bins by SHORT-range drift (the signal of interest: loop-scale
    # contacts are lost after RAD21 depletion).
    rank_order = np.argsort(drift_short)[::-1]
    top10_pct = rank_order[:max(1, n_bins//10)]
    overlap = int(anchor_mask[top10_pct].sum())
    print(f"  Anchor bins in top-10% SHORT-range drift: {overlap} "
          f"/ {n_anchors} ({100*overlap/max(1,n_anchors):.1f}%)")

    # Also report long-range drift summary (expected to be low: compartments stable)
    print(f"  Long-range drift mean (expected low, compartments preserved): "
          f"{drift_long.mean():.4f}")

    # ---- Step 7: Permutation test on SHORT-range drift at loop anchors ----
    print(f"\n[Step 7] Permutation test ({args.n_perm} shuffles) ...")
    print("  Testing: is SHORT-range drift enriched at CTCF/loop-anchor bins?")
    obs_anchor, obs_non, p_val, null_stats = _permutation_pvalue(
        drift_short, anchor_mask, n_perm=args.n_perm)

    print(f"  Mean short-range drift at anchors: {obs_anchor:.4f}")
    print(f"  Mean short-range drift at non-anchors: {obs_non:.4f}")
    print(f"  Empirical p-value: {p_val:.4f}")

    if np.isnan(p_val):
        print("  RESULT: Permutation test undefined (no anchor bins, or every bin "
              "is an anchor) — no enrichment claim can be made.")
    elif p_val < 0.05:
        print("  RESULT: Short-range drift is significantly enriched at loop anchors "
              "(p < 0.05) — consistent with loop-scale contact loss at CTCF sites.")
    elif p_val < 0.1:
        print("  RESULT: Trend toward short-range enrichment at loop anchors "
              "(0.05 ≤ p < 0.10) — marginal evidence.")
    else:
        print("  RESULT: Short-range drift is NOT significantly enriched at loop anchors "
              f"(p = {p_val:.3f}) — null result. Report honestly.")

    # ---- Save stats ----
    stats = {
        "chrom": args.chrom,
        "bin_size_bp": args.res,
        "short_cutoff_bp": args.short_cutoff,
        "long_cutoff_bp": args.long_cutoff,
        "n_bins": int(n_bins),
        "anchor_mode": anchor_mode,
        "n_bedpe_loops": int(n_bedpe_loops),
        "ctcf_percentile": args.ctcf_percentile if anchor_mode == "ctcf" else None,
        "n_anchor_bins": int(n_anchors),
        "drift_full_mean": float(drift_full.mean()),
        "drift_short_mean": float(drift_short.mean()),
        "drift_long_mean": float(drift_long.mean()),
        "signal_drift": "short_range",
        "obs_anchor_sr_mean": float(obs_anchor) if not np.isnan(obs_anchor) else None,
        "obs_non_sr_mean": float(obs_non) if not np.isnan(obs_non) else None,
        "drift_long_mean_note": "expected low (compartments preserved); "
                                "caveat: graph max_dist=5Mb does not reach "
                                "true compartment scale (>10Mb)",
        "perm_p_value": float(p_val) if not np.isnan(p_val) else None,
        "n_perm": args.n_perm,
        "top10pct_anchor_overlap_short_range": int(overlap),
    }
    with open(os.path.join(args.outdir, "drift_stats.json"), "w") as f:
        json.dump(stats, f, indent=2)

    np.save(os.path.join(args.outdir, "null_distribution.npy"),
            np.array(null_stats, dtype=float))

    print(f"\nSaved → {args.outdir}/")


if __name__ == "__main__":
    main()
|