Files
chromatin-vgae-hic/experiments/h2_rewiring/perturbation_analysis.py

420 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Perturbation drift analysis: Steps 5–7.
Encodes control and treated graphs with a trained VGAE, then decomposes
per-bin embedding drift into short-range and long-range components by
re-encoding with distance-filtered edge sets. Validates whether short-range
drift is enriched at loop anchor bins using a permutation test.
Biology of RAD21/cohesin depletion (Rao et al. 2017):
- Loops are lost: contacts at loop scale (~100 kb – 1 Mb)
- Compartments are preserved: contacts at chromosome scale (>10 Mb)
Therefore the expected drift pattern is:
- SHORT-RANGE drift HIGH at former loop anchors (loop contacts gone)
- LONG-RANGE drift LOW genome-wide (compartment contacts persist)
NOTE: the graph max_dist filter (default 5 Mb) means the "long-range" band
(2–5 Mb) is sub-compartment scale, not true compartment scale. We cannot
directly observe compartment preservation with this graph setup.
Inputs
------
control_graph.pt PyG Data object (control condition)
treated_graph.pt PyG Data object (treated/perturbed condition)
model.pt Trained VGAE state_dict (from train_vgae.py)
loops.bedpe Loop anchor coordinates (bedpe, from call_loops.py)
Outputs (under --outdir)
------------------------
control_emb.npy Full embeddings, control
treated_emb.npy Full embeddings, treated
drift_full.npy Per-bin cosine distance, full graph
drift_short.npy Per-bin cosine distance, edges < short_cutoff
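drift_long.npy Per-bin cosine distance, edges ≥ long_cutoff
anchor_mask.npy Bool array: True = loop anchor bin
anchor_mode.npy Anchor labelling mode used ("bedpe" or "ctcf")
drift_stats.json Summary statistics and permutation p-value
null_distribution.npy Permuted null statistics from the permutation test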
"""
import argparse
import json
import os
import sys
import numpy as np
import torch
from torch_geometric.nn.models import VGAE
from torch_geometric.utils import remove_self_loops, to_undirected
from chromatin_gnn.model import build_encoder
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _load_metrics(model_path: str) -> dict:
p = os.path.join(os.path.dirname(os.path.abspath(model_path)), "metrics.json")
if os.path.exists(p):
with open(p) as f:
return json.load(f)
return {}
def _encode(model, x, edge_index, edge_weight):
model.eval()
with torch.no_grad():
z = model.encode(x, edge_index, edge_weight)
return z.cpu().numpy()
def _cosine_dist(a: np.ndarray, b: np.ndarray) -> np.ndarray:
"""Row-wise cosine distance between two [N, D] arrays."""
norm_a = np.linalg.norm(a, axis=1, keepdims=True) + 1e-8
norm_b = np.linalg.norm(b, axis=1, keepdims=True) + 1e-8
sim = (a / norm_a * (b / norm_b)).sum(axis=1)
return 1.0 - np.clip(sim, -1.0, 1.0)
def _filter_edges_by_distance(edge_index: torch.Tensor,
edge_weight: torch.Tensor | None,
min_bins: int = 0,
max_bins: int = int(1e9)):
"""Keep only edges whose bin-distance falls within [min_bins, max_bins)."""
src, dst = edge_index[0], edge_index[1]
dist = (src - dst).abs()
mask = (dist >= min_bins) & (dist < max_bins)
ei = edge_index[:, mask]
ew = edge_weight[mask] if edge_weight is not None else None
return ei, ew
def _anchor_mask_from_bedpe(bedpe_path: str, chrom: str,
n_bins: int, bin_size: int) -> np.ndarray:
"""Return bool array of length n_bins; True where bin overlaps a loop anchor."""
mask = np.zeros(n_bins, dtype=bool)
try:
import pandas as pd
df = pd.read_csv(bedpe_path, sep="\t")
# Use both anchor columns (chrom1/start1/end1 and chrom2/start2/end2)
for side in [("chrom1", "start1", "end1"), ("chrom2", "start2", "end2")]:
c, s, e = side
if c not in df.columns:
continue
sub = df[df[c] == chrom]
for _, row in sub.iterrows():
lo = int(row[s]) // bin_size
hi = int(row[e]) // bin_size + 1
lo = max(0, lo)
hi = min(n_bins, hi)
mask[lo:hi] = True
except Exception as ex:
print(f"WARNING: could not parse {bedpe_path}: {ex}")
return mask
def _anchor_mask_from_ctcf(x: np.ndarray, ctcf_feat_idx: int = 0,
percentile: float = 75.0) -> np.ndarray:
"""
Define loop anchor bins by CTCF ChIP-seq signal.
CTCF marks the genomic anchor points of RAD21/cohesin-dependent loops.
Bins above `percentile` of CTCF signal are treated as potential loop anchors.
This proxy is used when Hi-C loop calls lack statistical power (e.g. due
to low sequencing depth), and is independent of the contact data under test.
Reference: Rao et al. 2017 — "CTCF is required for maintaining the
architecture of cohesin-mediated chromatin loops."
"""
ctcf = x[:, ctcf_feat_idx]
threshold = np.percentile(ctcf[ctcf > 0], percentile) if (ctcf > 0).any() else 0.0
mask = ctcf >= threshold
return mask.astype(bool)
def _permutation_pvalue(drift_lr: np.ndarray,
anchor_mask: np.ndarray,
n_perm: int = 1000,
rng_seed: int = 42) -> tuple:
"""
One-sided label-permutation test.
Null: long-range drift at randomly chosen bins equals that at loop anchors.
Observed statistic: mean(drift_lr[anchor]) - mean(drift_lr[~anchor]).
p-value: fraction of permutations where the shuffled statistic ≥ observed.
"""
rng = np.random.default_rng(rng_seed)
n_anchors = anchor_mask.sum()
if n_anchors == 0:
return np.nan, np.nan, np.nan, []
obs_anchor = drift_lr[anchor_mask].mean()
obs_non = drift_lr[~anchor_mask].mean() if (~anchor_mask).any() else np.nan
obs_stat = obs_anchor - obs_non
null_stats = []
for _ in range(n_perm):
perm_idx = rng.choice(len(drift_lr), size=int(n_anchors), replace=False)
perm_anchor = drift_lr[perm_idx].mean()
perm_non = np.delete(drift_lr, perm_idx).mean()
null_stats.append(perm_anchor - perm_non)
null_arr = np.array(null_stats)
p_val = float((null_arr >= obs_stat).mean())
return float(obs_anchor), float(obs_non), p_val, null_stats
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """
    Command-line entry point for Steps 5–7 of the perturbation drift analysis.

    Step 5: encode the full control and treated graphs with the trained VGAE
    and compute per-bin cosine drift.
    Step 6: re-encode using only short-range (|Δbin| < short_bins) or
    long-range (|Δbin| ≥ long_bins) edges. Then label loop-anchor bins, either
    from the bedpe (if it has ≥ --min_loops loops with chrom1 == --chrom) or
    from the CTCF proxy (feature column 0 of the control graph).
    Step 7: permutation test for short-range drift enrichment at anchor bins.
    All arrays plus a JSON summary are written under --outdir.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    ap.add_argument("--control_graph", required=True)
    ap.add_argument("--treated_graph", required=True)
    ap.add_argument("--model", required=True,
                    help="model.pt saved by train_vgae.py")
    ap.add_argument("--loops", default=None,
                    help="Loop call bedpe (from call_loops.py). "
                         "If omitted or if fewer than --min_loops loops are found, "
                         "falls back to CTCF-signal anchor labelling.")
    ap.add_argument("--min_loops", type=int, default=10,
                    help="Minimum bedpe loops required to use Hi-C anchor mode; "
                         "below this threshold CTCF proxy is used (default 10)")
    ap.add_argument("--ctcf_percentile", type=float, default=75.0,
                    help="CTCF signal percentile threshold for anchor labelling "
                         "when falling back to CTCF proxy mode (default 75)")
    ap.add_argument("--chrom", default="chr1")
    ap.add_argument("--res", type=int, default=25_000,
                    help="Bin size in bp (must match graph resolution)")
    ap.add_argument("--short_cutoff", type=int, default=1_000_000,
                    help="Short-range upper bound in bp (default 1 Mb; "
                         "captures full loop scale of ~100 kb–1 Mb)")
    ap.add_argument("--long_cutoff", type=int, default=2_000_000,
                    help="Long-range lower bound in bp (default 2 Mb; "
                         "note: true compartment contacts are >10 Mb and "
                         "are beyond the graph max_dist filter)")
    ap.add_argument("--n_perm", type=int, default=1000)
    ap.add_argument("--outdir", default="results/perturbation")
    ap.add_argument("--device", default="auto",
                    help="cuda | cpu | auto (default: auto-detect CUDA)")
    args = ap.parse_args()
    if args.res <= 0:
        ap.error("--res must be a positive bin size in bp")

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}"
          + (f" ({torch.cuda.get_device_name(device)})" if device.type == "cuda" else ""))
    os.makedirs(args.outdir, exist_ok=True)

    short_bins = args.short_cutoff // args.res  # short-range: |Δbin| <  short_bins
    long_bins = args.long_cutoff // args.res    # long-range:  |Δbin| >= long_bins
    print(f"Short-range: edges < {args.short_cutoff//1000} kb "
          f"(|Δbin| < {short_bins})")
    print(f"Long-range: edges ≥ {args.long_cutoff//1000} kb "
          f"(|Δbin| ≥ {long_bins})")

    # ---- Load graphs ----
    # map_location="cpu": graphs saved from a GPU session still load on a
    # CPU-only host; tensors are moved to `device` explicitly below.
    # NOTE(security): weights_only=False unpickles arbitrary objects (needed
    # for PyG Data objects); only load graph files you trust.
    ctrl_data = torch.load(args.control_graph, map_location="cpu", weights_only=False)
    trt_data = torch.load(args.treated_graph, map_location="cpu", weights_only=False)
    for d, name in [(ctrl_data, "control"), (trt_data, "treated")]:
        ew = getattr(d, "edge_weight", None)
        if ew is not None:
            # Drop self-loops from the weights too. Filtering edge_index alone
            # would leave edge_weight longer than, and misaligned with, the edges.
            ei, ew = remove_self_loops(d.edge_index, ew)
            ei, ew = to_undirected(ei, ew, num_nodes=d.num_nodes)
        else:
            ei, _ = remove_self_loops(d.edge_index)
            ei = to_undirected(ei, num_nodes=d.num_nodes)
        d.edge_index = ei
        d.edge_weight = ew
        print(f"Graph {name}: {d.num_nodes} nodes, {ei.shape[1]} edges, "
              f"{d.x.shape[1]} node features")
    # Per-bin drift compares row i of both embeddings, so both graphs must use
    # the same binning and feature layout.
    if (ctrl_data.num_nodes != trt_data.num_nodes
            or ctrl_data.x.shape[1] != trt_data.x.shape[1]):
        raise SystemExit(
            f"ERROR: control ({ctrl_data.num_nodes} nodes, "
            f"{ctrl_data.x.shape[1]} features) and treated "
            f"({trt_data.num_nodes} nodes, {trt_data.x.shape[1]} features) "
            f"graphs differ; per-bin drift needs identical binning and features.")
    n_bins = ctrl_data.num_nodes

    # ---- Load model ----
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects; only point --model at checkpoints you trust.
    state_dict = torch.load(args.model, map_location=device, weights_only=False)
    metrics = _load_metrics(args.model)
    encoder_name = metrics.get("encoder", "deep_gcn")
    hidden = metrics.get("hidden", 128)
    latent = metrics.get("latent", 64)
    heads = metrics.get("heads") or 4
    in_dim = ctrl_data.x.shape[1]
    enc = build_encoder(encoder_name, in_dim=in_dim, hidden=hidden,
                        latent=latent, dropout=0.0, heads=heads)
    model = VGAE(enc).to(device)
    model.load_state_dict(state_dict)
    model.eval()
    print(f"Loaded model: {encoder_name}, in={in_dim}, hidden={hidden}, latent={latent}")

    # Move graph tensors to device
    x_ctrl = ctrl_data.x.float().to(device)
    x_trt = trt_data.x.float().to(device)
    ctrl_data.edge_index = ctrl_data.edge_index.to(device)
    trt_data.edge_index = trt_data.edge_index.to(device)
    if ctrl_data.edge_weight is not None:
        ctrl_data.edge_weight = ctrl_data.edge_weight.to(device)
    if trt_data.edge_weight is not None:
        trt_data.edge_weight = trt_data.edge_weight.to(device)

    # ---- Step 5: Full embeddings ----
    print("\n[Step 5] Encoding full graphs ...")
    ctrl_emb = _encode(model, x_ctrl, ctrl_data.edge_index, ctrl_data.edge_weight)
    trt_emb = _encode(model, x_trt, trt_data.edge_index, trt_data.edge_weight)
    np.save(os.path.join(args.outdir, "control_emb.npy"), ctrl_emb)
    np.save(os.path.join(args.outdir, "treated_emb.npy"), trt_emb)
    drift_full = _cosine_dist(ctrl_emb, trt_emb)
    np.save(os.path.join(args.outdir, "drift_full.npy"), drift_full)
    print(f"Full drift: mean={drift_full.mean():.4f}, "
          f"median={np.median(drift_full):.4f}, max={drift_full.max():.4f}")

    # ---- Step 6: Short-range and long-range drift ----
    print("\n[Step 6] Short-range drift ...")
    ctrl_ei_sr, ctrl_ew_sr = _filter_edges_by_distance(
        ctrl_data.edge_index, ctrl_data.edge_weight,
        min_bins=0, max_bins=short_bins)
    trt_ei_sr, trt_ew_sr = _filter_edges_by_distance(
        trt_data.edge_index, trt_data.edge_weight,
        min_bins=0, max_bins=short_bins)
    print(f"  Control short-range edges: {ctrl_ei_sr.shape[1]}")
    print(f"  Treated short-range edges: {trt_ei_sr.shape[1]}")
    ctrl_emb_sr = _encode(model, x_ctrl, ctrl_ei_sr, ctrl_ew_sr)
    trt_emb_sr = _encode(model, x_trt, trt_ei_sr, trt_ew_sr)
    drift_short = _cosine_dist(ctrl_emb_sr, trt_emb_sr)
    np.save(os.path.join(args.outdir, "drift_short.npy"), drift_short)
    print(f"  Short-range drift: mean={drift_short.mean():.4f}, "
          f"max={drift_short.max():.4f}")

    print("\n[Step 6] Long-range drift ...")
    ctrl_ei_lr, ctrl_ew_lr = _filter_edges_by_distance(
        ctrl_data.edge_index, ctrl_data.edge_weight,
        min_bins=long_bins, max_bins=int(1e9))
    trt_ei_lr, trt_ew_lr = _filter_edges_by_distance(
        trt_data.edge_index, trt_data.edge_weight,
        min_bins=long_bins, max_bins=int(1e9))
    print(f"  Control long-range edges: {ctrl_ei_lr.shape[1]}")
    print(f"  Treated long-range edges: {trt_ei_lr.shape[1]}")
    ctrl_emb_lr = _encode(model, x_ctrl, ctrl_ei_lr, ctrl_ew_lr)
    trt_emb_lr = _encode(model, x_trt, trt_ei_lr, trt_ew_lr)
    drift_long = _cosine_dist(ctrl_emb_lr, trt_emb_lr)
    np.save(os.path.join(args.outdir, "drift_long.npy"), drift_long)
    print(f"  Long-range drift: mean={drift_long.mean():.4f}, "
          f"max={drift_long.max():.4f}")

    # ---- Loop anchor labeling ----
    print("\n[Step 6] Defining loop anchor bins ...")
    n_bedpe_loops = 0
    if args.loops and os.path.exists(args.loops):
        import pandas as pd
        try:
            bedpe_df = pd.read_csv(args.loops, sep="\t")
            n_bedpe_loops = int((bedpe_df["chrom1"] == args.chrom).sum())
        except Exception as ex:  # best effort: fall back to the CTCF proxy
            print(f"  WARNING: could not read {args.loops}: {ex}")
            n_bedpe_loops = 0
    elif args.loops:
        print(f"  WARNING: --loops file not found: {args.loops}")

    if n_bedpe_loops >= args.min_loops:
        anchor_mode = "bedpe"
        anchor_mask = _anchor_mask_from_bedpe(
            args.loops, args.chrom, n_bins, args.res)
        print(f"  Mode: Hi-C loop anchors (bedpe), {n_bedpe_loops} loops")
    else:
        anchor_mode = "ctcf"
        if n_bedpe_loops > 0:
            print(f"  NOTE: Only {n_bedpe_loops} loops in bedpe (< min_loops={args.min_loops}). "
                  f"Likely low sequencing depth in G1/S-synchronized sample. "
                  f"Falling back to CTCF ChIP-seq peak anchor proxy.")
        else:
            print("  NOTE: No bedpe loops. Using CTCF ChIP-seq peak proxy for anchors.")
        # x_ctrl is already float; .cpu() makes .numpy() safe whatever the device.
        anchor_mask = _anchor_mask_from_ctcf(
            x_ctrl.cpu().numpy(), ctcf_feat_idx=0,
            percentile=args.ctcf_percentile)
        print(f"  Mode: CTCF signal proxy (percentile ≥ {args.ctcf_percentile})")
    np.save(os.path.join(args.outdir, "anchor_mask.npy"), anchor_mask)
    np.save(os.path.join(args.outdir, "anchor_mode.npy"),
            np.array([anchor_mode]))
    n_anchors = int(anchor_mask.sum())
    print(f"  Loop anchor bins: {n_anchors} / {n_bins} "
          f"({100*n_anchors/max(1, n_bins):.1f}%)")

    # Rank bins by SHORT-range drift (this is the signal of interest:
    # loop-scale contacts are lost after RAD21 depletion)
    rank_order = np.argsort(drift_short)[::-1]
    top10_pct = rank_order[:max(1, n_bins//10)]
    overlap = anchor_mask[top10_pct].sum()
    print(f"  Anchor bins in top-10% SHORT-range drift: {overlap} "
          f"/ {n_anchors} ({100*overlap/max(1,n_anchors):.1f}%)")
    # Also report long-range drift summary (expected to be low: compartments stable)
    print(f"  Long-range drift mean (expected low, compartments preserved): "
          f"{drift_long.mean():.4f}")

    # ---- Step 7: Permutation test on SHORT-range drift at loop anchors ----
    print(f"\n[Step 7] Permutation test ({args.n_perm} shuffles) ...")
    print("  Testing: is SHORT-range drift enriched at CTCF/loop-anchor bins?")
    obs_anchor, obs_non, p_val, null_stats = _permutation_pvalue(
        drift_short, anchor_mask, n_perm=args.n_perm)
    print(f"  Mean short-range drift at anchors:     {obs_anchor:.4f}")
    print(f"  Mean short-range drift at non-anchors: {obs_non:.4f}")
    print(f"  Empirical p-value: {p_val:.4f}")
    if np.isnan(p_val):
        # NaN means the test could not be run (no anchors, or no non-anchors);
        # do not report it as a negative result.
        print("  RESULT: permutation test undefined (no anchor bins, or every "
              "bin is an anchor) — cannot assess enrichment.")
    elif p_val < 0.05:
        print("  RESULT: Short-range drift is significantly enriched at loop anchors "
              "(p < 0.05) — consistent with loop-scale contact loss at CTCF sites.")
    elif p_val < 0.1:
        print("  RESULT: Trend toward short-range enrichment at loop anchors "
              "(0.05 ≤ p < 0.10) — marginal evidence.")
    else:
        print("  RESULT: Short-range drift is NOT significantly enriched at loop anchors "
              f"(p = {p_val:.3f}) — null result. Report honestly.")

    # ---- Save stats ----
    stats = {
        "chrom": args.chrom,
        "bin_size_bp": args.res,
        "short_cutoff_bp": args.short_cutoff,
        "long_cutoff_bp": args.long_cutoff,
        "n_bins": int(n_bins),
        "anchor_mode": anchor_mode,
        "n_bedpe_loops": int(n_bedpe_loops),
        "ctcf_percentile": args.ctcf_percentile if anchor_mode == "ctcf" else None,
        "n_anchor_bins": int(n_anchors),
        "drift_full_mean": float(drift_full.mean()),
        "drift_short_mean": float(drift_short.mean()),
        "drift_long_mean": float(drift_long.mean()),
        "signal_drift": "short_range",
        "obs_anchor_sr_mean": float(obs_anchor) if not np.isnan(obs_anchor) else None,
        "obs_non_sr_mean": float(obs_non) if not np.isnan(obs_non) else None,
        "drift_long_mean_note": "expected low (compartments preserved); "
                                "caveat: graph max_dist=5Mb does not reach "
                                "true compartment scale (>10Mb)",
        "perm_p_value": float(p_val) if not np.isnan(p_val) else None,
        "n_perm": args.n_perm,
        "top10pct_anchor_overlap_short_range": int(overlap),
    }
    with open(os.path.join(args.outdir, "drift_stats.json"), "w") as f:
        json.dump(stats, f, indent=2)
    np.save(os.path.join(args.outdir, "null_distribution.npy"),
            np.array(null_stats, dtype=float))
    print(f"\nSaved → {args.outdir}/")


if __name__ == "__main__":
    main()