183 lines
7.5 KiB
Python
183 lines
7.5 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
H3 feature × edge cross-ablation evaluation.
|
||
|
||
Builds a 2×3 grid:
|
||
|
||
                         full edges   local (<250 kb)   long-range (>1 Mb)
    real features            A               B                  C
    constant features        D               E                  F
|
||
|
||
A–C come from the existing H1/H3 trainings (real features × edge subset).
|
||
D–F are new trainings with `--constant_features` on the same edge subsets.
|
||
|
||
For each cell we report:
|
||
• Link prediction test AUC/AP (from cell's metrics.json)
|
||
• Compartment classification AUC/AP (5-fold CV LogReg on embeddings)
|
||
• Spearman r(PC1, best latent dim)
|
||
|
||
The cross-ablation separates two confounded sources of signal in H3:
|
||
• Feature signal: epigenetic marks themselves correlate with compartment;
|
||
message-passing smooths them over adjacent bins regardless of topology.
|
||
• Topology signal: information genuinely carried by graph structure.
|
||
|
||
Reading the table:
|
||
• Δ (real − const) at fixed edges = contribution of features given that topology.
|
||
• Δ across edges at fixed features = contribution of edge-band topology.
|
||
• The bottom row alone shows what each edge-band carries from topology alone.
|
||
|
||
Usage
|
||
-----
|
||
python experiments/h3_longrange/evaluate_cross_ablation.py \\
|
||
--full_real_dir results/h1_representation \\
|
||
--local_real_dir results/h3_longrange/local_only \\
|
||
--longrange_real_dir results/h3_longrange/longrange_only \\
|
||
--full_const_dir results/h3_longrange/full_const \\
|
||
--local_const_dir results/h3_longrange/local_const \\
|
||
--longrange_const_dir results/h3_longrange/longrange_const \\
|
||
--compartments results/h1_representation/compartments/gm12878_chr1.csv \\
|
||
--out results/h3_longrange/cross_ablation.json
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import os
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
from scipy.stats import spearmanr
|
||
from sklearn.linear_model import LogisticRegression
|
||
from sklearn.metrics import roc_auc_score, average_precision_score
|
||
from sklearn.model_selection import StratifiedKFold
|
||
from sklearn.preprocessing import StandardScaler
|
||
|
||
|
||
def cv_auc(X, y, n_splits=5, seed=42):
    """Score a logistic-regression probe with stratified k-fold cross-validation.

    For every fold a ``StandardScaler`` is fitted on the training rows only,
    a ``LogisticRegression`` (``max_iter=1000``) is trained on the scaled
    training rows, and the held-out rows are scored with ROC-AUC and average
    precision using the probability of the second class.

    Args:
        X: Feature matrix; must support integer-array row indexing
            (assumed to be a 2-D NumPy array -- TODO confirm with callers).
        y: Label vector aligned with the rows of ``X`` (presumably binary
            0/1, since only the second probability column is scored).
        n_splits: Number of stratified folds.
        seed: Random state shared by the fold shuffling and the classifier.

    Returns:
        ``(mean_auc, mean_ap)`` as plain Python floats, averaged over folds.
    """
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

    fold_scores = []
    for train_idx, test_idx in splitter.split(X, y):
        # Fit the scaler on the training fold only so no test statistics leak in.
        scaler = StandardScaler()
        scaler.fit(X[train_idx])

        clf = LogisticRegression(max_iter=1000, random_state=seed)
        clf.fit(scaler.transform(X[train_idx]), y[train_idx])

        y_true = y[test_idx]
        p_pos = clf.predict_proba(scaler.transform(X[test_idx]))[:, 1]
        fold_scores.append(
            (roc_auc_score(y_true, p_pos), average_precision_score(y_true, p_pos))
        )

    auc_per_fold, ap_per_fold = zip(*fold_scores)
    return float(np.mean(auc_per_fold)), float(np.mean(ap_per_fold))
|
||
|
||
|
||
def best_spearman(emb, pc1):
    """Find the latent dimension whose values best track ``pc1`` by rank.

    The absolute Spearman correlation between ``pc1`` and every column of
    ``emb`` is computed, and the strongest one is reported.

    ``spearmanr`` yields NaN for a constant column (rank correlation is
    undefined there). Such columns are skipped when choosing the best
    dimension, so a single collapsed latent dimension cannot turn the whole
    result into NaN.

    Args:
        emb: 2-D array indexed as ``emb[:, d]``; one column per latent dim.
        pc1: 1-D sequence with one value per row of ``emb``.

    Returns:
        ``(r, dim)``: the largest absolute correlation as a float and the
        column index that achieved it as an int. If every column has an
        undefined correlation the result is ``(nan, 0)``.

    Raises:
        ValueError: If ``emb`` has no columns.
    """
    n_dims = emb.shape[1]
    if n_dims == 0:
        raise ValueError("best_spearman: embedding has no latent dimensions")

    rs = np.array(
        [abs(spearmanr(emb[:, d], pc1).statistic) for d in range(n_dims)],
        dtype=float,
    )

    # Every dimension undefined: keep reporting NaN at index 0 rather than
    # letting nanargmax raise on an all-NaN input.
    if np.isnan(rs).all():
        return float("nan"), 0

    best_dim = int(np.nanargmax(rs))
    return float(rs[best_dim]), best_dim
|
||
|
||
|
||
def _emb_name(d):
    """Locate the embedding ``.npy`` file inside result directory ``d``.

    Two file names are tried in order: ``emb.npy`` (used by the H3 cells) and
    ``gm12878_emb.npy`` (used by H1). The first one that exists wins.

    Args:
        d: Directory to search.

    Returns:
        The path of the first existing candidate, built with ``os.path.join``.

    Raises:
        FileNotFoundError: If neither candidate exists in ``d``.
    """
    candidates = [os.path.join(d, fname) for fname in ("emb.npy", "gm12878_emb.npy")]
    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is None:
        raise FileNotFoundError(f"No emb file in {d}")
    return found
|
||
|
||
|
||
def evaluate_cell(emb_dir, pc1, y, mask):
    """Collect all reported metrics for one cell of the cross-ablation grid.

    Loads the cell's embedding array and ``metrics.json`` from ``emb_dir``,
    restricts the embedding rows with ``mask``, then runs the compartment
    probe (``cv_auc``) and the rank-correlation search (``best_spearman``)
    on the selected rows.

    Args:
        emb_dir: Directory holding the embedding file and ``metrics.json``.
        pc1: PC1 values for the selected rows, passed to ``best_spearman``.
        y: Labels for the selected rows, passed to ``cv_auc``.
        mask: Index applied to the embedding rows (a boolean array in
            ``main``; assumed to align one-to-one with embedding rows --
            TODO confirm).

    Returns:
        A dict with the link-prediction metrics copied from ``metrics.json``
        (``None`` when absent), the computed compartment AUC/AP, the best
        Spearman correlation and its dimension, and training metadata
        (``constant_features`` defaults to ``False`` when absent).
    """
    metrics_path = os.path.join(emb_dir, "metrics.json")

    embeddings = np.load(_emb_name(emb_dir))
    with open(metrics_path) as fh:
        metrics = json.load(fh)

    selected = embeddings[mask]
    comp_auc, comp_ap = cv_auc(selected, y)
    best_r, best_dim = best_spearman(selected, pc1)

    # Insertion order is kept identical to the on-disk JSON layout.
    cell = {
        "link_pred_test_auc": metrics.get("test_auc"),
        "link_pred_test_ap": metrics.get("test_ap"),
        "compartment_auc_5fold": comp_auc,
        "compartment_ap_5fold": comp_ap,
        "spearman_r_best": best_r,
        "spearman_best_dim": best_dim,
    }
    cell.update((field, metrics.get(field)) for field in ("epochs_ran", "hidden", "latent"))
    cell["constant_features"] = metrics.get("constant_features", False)
    cell["in_features"] = metrics.get("in_features")
    return cell
|
||
|
||
|
||
def _fmt_metric(value):
    """Render a metric with four decimals, or ``n/a`` when it is ``None``."""
    return "n/a" if value is None else f"{value:.4f}"


def main():
    """Evaluate the six cross-ablation cells, save them as JSON, print summaries.

    Steps:
      1. Parse the six result directories, the compartment CSV and the
         output path from the command line (all required).
      2. Read the compartment CSV (columns ``compartment`` and ``pc1``) and
         keep rows labelled ``A`` or ``B`` with a non-missing ``pc1``;
         ``A`` is the positive class.
      3. Run ``evaluate_cell`` for each (features, edges) combination and
         print a one-line summary per cell.
      4. Write all results to ``--out`` (creating its directory if needed).
      5. Print the 2x3 compartment-AUC table, the real-minus-constant
         deltas, the constant-feature AUCs against 0.50, and a short
         interpretation comparing local and long-range topology.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--full_real_dir", required=True)
    ap.add_argument("--local_real_dir", required=True)
    ap.add_argument("--longrange_real_dir", required=True)
    ap.add_argument("--full_const_dir", required=True)
    ap.add_argument("--local_const_dir", required=True)
    ap.add_argument("--longrange_const_dir", required=True)
    ap.add_argument("--compartments", required=True)
    ap.add_argument("--out", required=True)
    args = ap.parse_args()

    df = pd.read_csv(args.compartments)
    valid = df["compartment"].isin(["A", "B"]) & df["pc1"].notna()
    mask = valid.values
    pc1 = df.loc[valid, "pc1"].values
    y = (df.loc[valid, "compartment"] == "A").astype(int).values
    print(f"Compartment labels: {mask.sum()} bins (A={y.sum()}, B={(~y.astype(bool)).sum()})")

    edges = ("full", "local", "longrange")
    grid = {
        ("real", "full"): args.full_real_dir,
        ("real", "local"): args.local_real_dir,
        ("real", "longrange"): args.longrange_real_dir,
        ("const", "full"): args.full_const_dir,
        ("const", "local"): args.local_const_dir,
        ("const", "longrange"): args.longrange_const_dir,
    }

    results = {}
    for (feat, edge), d in grid.items():
        key = f"{feat}_{edge}"
        results[key] = evaluate_cell(d, pc1, y, mask)
        r = results[key]
        print(f"\n=== {key} ({d}) ===")
        print(f" in_features={r['in_features']} hidden={r['hidden']} epochs={r['epochs_ran']}")
        # The link-prediction AUC is read with .get() and may be None when
        # metrics.json lacks it; formatting None with ':.4f' would raise
        # TypeError and abort the run before anything is saved.
        print(f" Link AUC {_fmt_metric(r['link_pred_test_auc'])} "
              f"Comp AUC {r['compartment_auc_5fold']:.4f} "
              f"Spearman r {r['spearman_r_best']:.4f}")

    # Make sure the destination directory exists so the finished evaluation
    # is not lost to a missing folder.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved → {args.out}")

    # 2×3 table
    print("\n=== Cross-ablation: compartment AUC (5-fold CV LogReg on embeddings) ===")
    print(f"{'features ↓ / edges →':<22} {'full':>10} {'local':>10} {'longrange':>12}")
    print("-" * 56)
    for feat in ("real", "const"):
        row = [results[f"{feat}_{e}"]["compartment_auc_5fold"] for e in edges]
        print(f"{feat:<22} {row[0]:>10.4f} {row[1]:>10.4f} {row[2]:>12.4f}")

    print("\n=== Feature contribution at fixed topology (real − const) ===")
    for e in edges:
        delta = (results[f"real_{e}"]["compartment_auc_5fold"]
                 - results[f"const_{e}"]["compartment_auc_5fold"])
        print(f" {e:<10} Δ = {delta:+.4f}")

    print("\n=== Topology contribution at constant features (vs random ~0.50) ===")
    for e in edges:
        v = results[f"const_{e}"]["compartment_auc_5fold"]
        delta = v - 0.50
        print(f" {e:<10} AUC = {v:.4f} Δ vs random = {delta:+.4f}")

    print("\n=== Interpretation ===")
    # Differences smaller than this are treated as noise, not as a winner.
    noise_band = 0.02
    lc = results["const_local"]["compartment_auc_5fold"]
    rc = results["const_longrange"]["compartment_auc_5fold"]
    if rc > lc + noise_band:
        print(" → Long-range topology > local topology under constant features.")
        print(" Original H3 result (local > longrange) was driven by feature smoothing,")
        print(" not by topology. H3 is supported once features are controlled for.")
    elif lc > rc + noise_band:
        print(" → Local topology > long-range topology even under constant features.")
        print(" H3 not supported. Local contacts dominate by topology, not by feature smoothing.")
    else:
        print(" → Local and long-range topology contribute equally under constant features.")
        print(f" Δ = {rc - lc:+.4f} (within ±{noise_band:.2f} noise band).")
|
||
|
||
|
||
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|