Files
chromatin-vgae-hic/experiments/h3_longrange/evaluate_cross_ablation.py

183 lines
7.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
H3 feature × edge cross-ablation evaluation.
Builds a 2×3 grid:
                     full edges   local (<250 kb)   long-range (>1 Mb)
real features            A               B                  C
constant features        D               E                  F
Cells A, B, C come from the existing H1/H3 trainings (real features × edge subset).
Cells D, E, F are new trainings with `--constant_features` on the same edge subsets.
For each cell we report:
• Link prediction test AUC/AP (from cell's metrics.json)
• Compartment classification AUC/AP (5-fold CV LogReg on embeddings)
• Spearman r(PC1, best latent dim)
The cross-ablation separates two confounded sources of signal in H3:
• Feature signal: epigenetic marks themselves correlate with compartment;
message-passing smooths them over adjacent bins regardless of topology.
• Topology signal: information genuinely carried by graph structure.
Reading the table:
• Δ (real - const) at fixed edges = contribution of features given that topology.
• Δ across edges at fixed features = contribution of edge-band topology.
• The bottom row alone shows what each edge-band carries from topology alone.
Usage
-----
python experiments/h3_longrange/evaluate_cross_ablation.py \\
--full_real_dir results/h1_representation \\
--local_real_dir results/h3_longrange/local_only \\
--longrange_real_dir results/h3_longrange/longrange_only \\
--full_const_dir results/h3_longrange/full_const \\
--local_const_dir results/h3_longrange/local_const \\
--longrange_const_dir results/h3_longrange/longrange_const \\
--compartments results/h1_representation/compartments/gm12878_chr1.csv \\
--out results/h3_longrange/cross_ablation.json
"""
import argparse
import json
import os
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
def cv_auc(X, y, n_splits=5, seed=42):
    """Score a logistic-regression probe with stratified k-fold cross-validation.

    For every fold, a StandardScaler is fitted on the training split only, a
    ``LogisticRegression(max_iter=1000)`` is trained on the scaled training
    rows, and the predicted probability of the second class (column 1 of
    ``predict_proba``) is scored on the held-out rows.

    Args:
        X: 2-D feature matrix indexed by row (here: per-bin embeddings).
        y: label array aligned with the rows of ``X``; assumed binary (0/1).
        n_splits: number of stratified folds.
        seed: seed shared by the fold shuffling and the classifier.

    Returns:
        ``(mean ROC AUC, mean average precision)`` over folds, as Python floats.
    """
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    auc_per_fold = []
    ap_per_fold = []
    for train_idx, test_idx in splitter.split(X, y):
        # Fit scaling on the training split only so no test statistics leak in.
        scaler = StandardScaler().fit(X[train_idx])
        clf = LogisticRegression(max_iter=1000, random_state=seed)
        clf.fit(scaler.transform(X[train_idx]), y[train_idx])
        p_pos = clf.predict_proba(scaler.transform(X[test_idx]))[:, 1]
        y_true = y[test_idx]
        auc_per_fold.append(roc_auc_score(y_true, p_pos))
        ap_per_fold.append(average_precision_score(y_true, p_pos))
    mean_auc = float(np.mean(auc_per_fold))
    mean_ap = float(np.mean(ap_per_fold))
    return mean_auc, mean_ap
def best_spearman(emb, pc1):
    """Find the embedding column with the strongest rank correlation to PC1.

    Args:
        emb: 2-D array; column ``d`` (``emb[:, d]``) is one latent dimension,
            rows are aligned with ``pc1``.
        pc1: 1-D array of PC1 values, one per row of ``emb``.

    Returns:
        ``(r, dim)``: the largest absolute Spearman correlation over all
        columns, and the index of that column.

        Columns whose correlation is undefined are ignored. A constant
        (collapsed) dimension is the typical case: ``spearmanr`` returns NaN
        for it. Previously such a NaN won ``np.max``/``np.argmax`` and hid
        every valid dimension. If no column yields a finite correlation, the
        result is ``(nan, -1)``.
    """
    rs = np.array(
        [abs(spearmanr(emb[:, d], pc1).statistic) for d in range(emb.shape[1])],
        dtype=float,
    )
    if not np.isfinite(rs).any():
        # Nothing usable (all-constant columns or no columns at all).
        return float("nan"), -1
    dim = int(np.nanargmax(rs))
    return float(rs[dim]), dim
def _emb_name(d):
    """Return the path of the embedding ``.npy`` file stored in directory *d*.

    H3 cell directories hold ``emb.npy``; the H1 run holds ``gm12878_emb.npy``.
    ``emb.npy`` is preferred when both exist.

    Raises:
        FileNotFoundError: if neither file is present in *d*.
    """
    candidates = (os.path.join(d, fname) for fname in ("emb.npy", "gm12878_emb.npy"))
    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is None:
        raise FileNotFoundError(f"No emb file in {d}")
    return found
def evaluate_cell(emb_dir, pc1, y, mask):
    """Evaluate one cross-ablation cell stored in *emb_dir*.

    The embedding file is located via ``_emb_name`` and loaded first, then
    ``metrics.json`` is read from the same directory. Embedding rows are
    restricted with the boolean ``mask``. Two scores are computed on the kept
    rows: the compartment-probe AUC/AP (``cv_auc`` against ``y``) and the best
    per-dimension |Spearman| with ``pc1`` (``best_spearman``).

    Link-prediction scores and training metadata are copied from
    ``metrics.json``. Keys absent there come through as ``None``, except
    ``constant_features``, which defaults to ``False``.

    NOTE(review): assumes embedding rows are aligned 1:1 with the rows of the
    compartment table that produced ``mask``; confirm against the trainers.
    """
    metrics_path = os.path.join(emb_dir, "metrics.json")
    embedding = np.load(_emb_name(emb_dir))
    with open(metrics_path) as fh:
        metrics = json.load(fh)

    labelled = embedding[mask]
    comp_auc, comp_ap = cv_auc(labelled, y)
    r_best, r_dim = best_spearman(labelled, pc1)

    # Key order is preserved so the written JSON is laid out as before.
    summary = {
        "link_pred_test_auc": metrics.get("test_auc"),
        "link_pred_test_ap": metrics.get("test_ap"),
        "compartment_auc_5fold": comp_auc,
        "compartment_ap_5fold": comp_ap,
        "spearman_r_best": r_best,
        "spearman_best_dim": r_dim,
    }
    for passthrough in ("epochs_ran", "hidden", "latent"):
        summary[passthrough] = metrics.get(passthrough)
    summary["constant_features"] = metrics.get("constant_features", False)
    summary["in_features"] = metrics.get("in_features")
    return summary
def _fmt(value, spec=".4f"):
    """Format an optional metric; ``None`` (key absent from metrics.json) prints as "n/a"."""
    if value is None:
        return "n/a"
    return format(value, spec)


def main():
    """Run the 2×3 feature × edge cross-ablation and print summary tables.

    Steps:
      1. Read the compartment CSV (columns ``compartment`` and ``pc1``).
      2. Keep only bins labelled "A" or "B" that have a non-missing PC1.
      3. Evaluate each of the six grid cells with ``evaluate_cell``.
      4. Write all per-cell results to ``--out`` as JSON, creating the parent
         directory if needed.
      5. Print the compartment-AUC table, the feature contrast at fixed edges,
         the topology contrast against chance (0.50), and an interpretation of
         local vs long-range topology under constant features.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--full_real_dir", required=True)
    parser.add_argument("--local_real_dir", required=True)
    parser.add_argument("--longrange_real_dir", required=True)
    parser.add_argument("--full_const_dir", required=True)
    parser.add_argument("--local_const_dir", required=True)
    parser.add_argument("--longrange_const_dir", required=True)
    parser.add_argument("--compartments", required=True)
    parser.add_argument("--out", required=True)
    args = parser.parse_args()

    # Only bins with a definite A/B call and a usable PC1 enter the evaluation.
    df = pd.read_csv(args.compartments)
    valid = df["compartment"].isin(["A", "B"]) & df["pc1"].notna()
    mask = valid.values
    pc1 = df.loc[valid, "pc1"].values
    y = (df.loc[valid, "compartment"] == "A").astype(int).values
    print(f"Compartment labels: {mask.sum()} bins (A={y.sum()}, B={(~y.astype(bool)).sum()})")

    edges = ("full", "local", "longrange")
    grid = {
        ("real", "full"): args.full_real_dir,
        ("real", "local"): args.local_real_dir,
        ("real", "longrange"): args.longrange_real_dir,
        ("const", "full"): args.full_const_dir,
        ("const", "local"): args.local_const_dir,
        ("const", "longrange"): args.longrange_const_dir,
    }
    results = {}
    for (feat, edge), d in grid.items():
        key = f"{feat}_{edge}"
        r = evaluate_cell(d, pc1, y, mask)
        results[key] = r
        print(f"\n=== {key} ({d}) ===")
        print(f" in_features={r['in_features']} hidden={r['hidden']} epochs={r['epochs_ran']}")
        # test_auc may be missing from metrics.json (read with .get -> None);
        # formatting None with :.4f would raise TypeError after all the CV work.
        print(f" Link AUC {_fmt(r['link_pred_test_auc'])} Comp AUC {r['compartment_auc_5fold']:.4f} "
              f"Spearman r {r['spearman_r_best']:.4f}")

    # Make sure the output directory exists before writing the results file.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved → {args.out}")

    # 2×3 table
    print("\n=== Cross-ablation: compartment AUC (5-fold CV LogReg on embeddings) ===")
    header = f"{'features ↓ / edges →':<22} {'full':>10} {'local':>10} {'longrange':>12}"
    print(header)
    print("-" * len(header))
    for feat in ("real", "const"):
        row = [results[f"{feat}_{e}"]["compartment_auc_5fold"] for e in edges]
        print(f"{feat:<22} {row[0]:>10.4f} {row[1]:>10.4f} {row[2]:>12.4f}")

    print("\n=== Feature contribution at fixed topology (real - const) ===")
    for e in edges:
        delta = (results[f"real_{e}"]["compartment_auc_5fold"]
                 - results[f"const_{e}"]["compartment_auc_5fold"])
        print(f" {e:<10} Δ = {delta:+.4f}")

    chance_auc = 0.50
    print("\n=== Topology contribution at constant features (vs random ~0.50) ===")
    for e in edges:
        v = results[f"const_{e}"]["compartment_auc_5fold"]
        print(f" {e:<10} AUC = {v:.4f} Δ vs random = {v - chance_auc:+.4f}")

    print("\n=== Interpretation ===")
    margin = 0.02  # AUC differences within ±margin are treated as noise
    lc = results["const_local"]["compartment_auc_5fold"]
    rc = results["const_longrange"]["compartment_auc_5fold"]
    if rc > lc + margin:
        print(" → Long-range topology > local topology under constant features.")
        print(" Original H3 result (local > longrange) was driven by feature smoothing,")
        print(" not by topology. H3 is supported once features are controlled for.")
    elif lc > rc + margin:
        print(" → Local topology > long-range topology even under constant features.")
        print(" H3 not supported. Local contacts dominate by topology, not by feature smoothing.")
    else:
        print(" → Local and long-range topology contribute equally under constant features.")
        print(f" Δ = {rc - lc:+.4f} (within ±0.02 noise band).")


if __name__ == "__main__":
    main()