155 lines
6.2 KiB
Python
155 lines
6.2 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
H3 ablation evaluation: compare full vs local-only vs longrange-only VGAE variants.
|
|
|
|
For each variant reports:
|
|
• Link prediction test AUC/AP (from variant's metrics.json)
|
|
• Compartment classification AUC (5-fold CV logistic regression on embeddings)
|
|
• Spearman r(PC1, best latent dim)
|
|
• Number of edges, epochs trained
|
|
|
|
The full variant is reused from H1; local- and longrange-only are H3 trainings.
|
|
|
|
Usage
|
|
-----
|
|
python experiments/h3_longrange/evaluate_ablation.py \\
|
|
--full_emb results/h1_representation/gm12878_emb.npy \\
|
|
--full_metrics results/h1_representation/metrics.json \\
|
|
--local_dir results/h3_longrange/local_only \\
|
|
--longrange_dir results/h3_longrange/longrange_only \\
|
|
--compartments results/h1_representation/compartments/gm12878_chr1.csv \\
|
|
--out results/h3_longrange/ablation_comparison.json
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from scipy.stats import spearmanr
|
|
from sklearn.linear_model import LogisticRegression
|
|
from sklearn.metrics import roc_auc_score, average_precision_score
|
|
from sklearn.model_selection import StratifiedKFold
|
|
from sklearn.preprocessing import StandardScaler
|
|
|
|
|
|
def cv_auc(X, y, n_splits=5, seed=42):
    """Score a logistic-regression probe on ``X`` with stratified k-fold CV.

    For every fold a ``StandardScaler`` is fitted on the training rows only,
    a ``LogisticRegression`` is trained on the scaled training rows, and the
    held-out rows are scored with the probability column at index 1.

    Args:
        X: Feature array indexed by row (one row per sample).
        y: Label array aligned with the rows of ``X``.
        n_splits: Number of stratified folds.
        seed: Seed for both the fold shuffling and the classifier.

    Returns:
        ``(mean ROC AUC, mean average precision)`` over the folds, as floats.
    """
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    fold_aucs = []
    fold_aps = []
    for train_idx, test_idx in splitter.split(X, y):
        x_train, x_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        # Scaling statistics come from the training fold only, so nothing
        # about the held-out rows leaks into the classifier.
        scaler = StandardScaler().fit(x_train)
        clf = LogisticRegression(max_iter=1000, random_state=seed)
        clf.fit(scaler.transform(x_train), y_train)

        scores = clf.predict_proba(scaler.transform(x_test))[:, 1]
        fold_aucs.append(roc_auc_score(y_test, scores))
        fold_aps.append(average_precision_score(y_test, scores))

    mean_auc = float(np.mean(fold_aucs))
    mean_ap = float(np.mean(fold_aps))
    return mean_auc, mean_ap
|
|
|
|
|
|
def best_spearman(emb, pc1):
    """Find the embedding dimension most rank-correlated with ``pc1``.

    Computes ``|Spearman rho|`` between every column of ``emb`` and ``pc1``
    and returns the largest one. Columns whose correlation is undefined
    (``spearmanr`` yields NaN, e.g. for a constant column) are ignored so
    that a single degenerate dimension cannot mask the real maximum.

    Args:
        emb: 2-D array; columns are the candidate dimensions.
        pc1: 1-D array aligned with the rows of ``emb``.

    Returns:
        ``(best |rho|, column index)``. If every correlation is NaN the
        result is ``(nan, 0)``.

    Raises:
        ValueError: If ``emb`` has no columns.
    """
    n_dims = emb.shape[1]
    if n_dims == 0:
        raise ValueError("best_spearman requires at least one embedding dimension")

    # Index [0] is the correlation coefficient of the spearmanr result.
    abs_rs = np.array(
        [abs(spearmanr(emb[:, d], pc1)[0]) for d in range(n_dims)],
        dtype=float,
    )

    if np.isnan(abs_rs).all():
        # Nothing is defined (e.g. pc1 itself is constant): report NaN at
        # index 0 rather than letting nanargmax raise.
        return float("nan"), 0

    best_dim = int(np.nanargmax(abs_rs))
    return float(abs_rs[best_dim]), best_dim
|
|
|
|
|
|
def evaluate_one(emb, metrics, pc1, y, mask):
    """Build the evaluation summary for a single model variant.

    Args:
        emb: Embedding array; it is indexed with ``mask`` before scoring.
        metrics: Mapping read with ``.get`` for ``"test_auc"``,
            ``"test_ap"`` and ``"epochs_ran"``; missing keys become ``None``.
        pc1: Values passed to ``best_spearman`` together with the masked
            embedding.
        y: Labels passed to ``cv_auc`` together with the masked embedding.
        mask: Index applied to ``emb`` to select the rows that are scored.

    Returns:
        Dict with the link-prediction metrics, the cross-validated
        compartment AUC/AP, the best absolute Spearman correlation with its
        dimension index, and the number of epochs run.
    """
    selected = emb[mask]
    comp_auc, comp_ap = cv_auc(selected, y)
    top_r, top_dim = best_spearman(selected, pc1)

    summary = {}
    summary["link_pred_test_auc"] = metrics.get("test_auc")
    summary["link_pred_test_ap"] = metrics.get("test_ap")
    summary["compartment_auc_5fold"] = comp_auc
    summary["compartment_ap_5fold"] = comp_ap
    summary["spearman_r_best"] = top_r
    summary["spearman_best_dim"] = top_dim
    summary["epochs_ran"] = metrics.get("epochs_ran")
    return summary
|
|
|
|
|
|
def _load_variant(emb_path, metrics_path):
|
|
emb = np.load(emb_path)
|
|
with open(metrics_path) as f:
|
|
m = json.load(f)
|
|
return emb, m
|
|
|
|
|
|
def _fmt_metric(value, spec, width=0):
    """Format ``value`` with ``spec`` (or ``"n/a"`` if None), right-aligned to ``width``."""
    text = "n/a" if value is None else format(value, spec)
    return text.rjust(width)


def main():
    """Command-line entry point: evaluate the three variants and compare them.

    Reads the compartment CSV, keeps rows whose ``compartment`` is ``A`` or
    ``B`` and whose ``pc1`` is not missing, evaluates the full, local-only
    and long-range-only variants with ``evaluate_one``, writes the results
    as JSON to ``--out`` and prints a per-variant report, a summary table
    and an interpretation of the compartment AUCs.

    Metrics that are absent from a variant's ``metrics.json`` are stored as
    ``None`` by ``evaluate_one``; they are printed as ``n/a`` instead of
    aborting the run before the results are saved.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--full_emb", required=True,
                        help="Path to full-model emb.npy (e.g. gm12878_emb.npy from H1)")
    parser.add_argument("--full_metrics", required=True,
                        help="Path to full-model metrics.json (from H1)")
    parser.add_argument("--local_dir", required=True,
                        help="Dir containing local-only emb.npy and metrics.json")
    parser.add_argument("--longrange_dir", required=True,
                        help="Dir containing longrange-only emb.npy and metrics.json")
    parser.add_argument("--compartments", required=True,
                        help="GM12878 compartment CSV (with pc1, compartment columns)")
    parser.add_argument("--out", required=True)
    args = parser.parse_args()

    df = pd.read_csv(args.compartments)
    valid = df["compartment"].isin(["A", "B"]) & df["pc1"].notna()
    # The boolean mask spans every CSV row and is used to index each
    # embedding; pc1 and y hold only the selected rows.
    mask = valid.values
    pc1 = df.loc[valid, "pc1"].values
    y = (df.loc[valid, "compartment"] == "A").astype(int).values
    print(f"Compartment labels: {mask.sum()} bins with A/B labels "
          f"(A={y.sum()}, B={(~y.astype(bool)).sum()})")

    variants = {
        "full": (args.full_emb,
                 args.full_metrics),
        "local_only": (os.path.join(args.local_dir, "emb.npy"),
                       os.path.join(args.local_dir, "metrics.json")),
        "longrange_only": (os.path.join(args.longrange_dir, "emb.npy"),
                           os.path.join(args.longrange_dir, "metrics.json")),
    }

    results = {}
    for name, (emb_p, met_p) in variants.items():
        emb, mets = _load_variant(emb_p, met_p)
        results[name] = evaluate_one(emb, mets, pc1, y, mask)
        r = results[name]
        link_auc = _fmt_metric(r["link_pred_test_auc"], ".4f")
        link_ap = _fmt_metric(r["link_pred_test_ap"], ".4f")
        print(f"\n=== {name} ===")
        print(f" Link prediction: AUC {link_auc} "
              f"AP {link_ap}")
        print(f" Compartment recovery: AUC {r['compartment_auc_5fold']:.4f} "
              f"AP {r['compartment_ap_5fold']:.4f}")
        print(f" Spearman r(PC1, dim {r['spearman_best_dim']}): {r['spearman_r_best']:.4f}")
        print(f" Epochs ran: {r['epochs_ran']}")

    # Create the destination directory up front so a missing folder cannot
    # discard the finished evaluation.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved → {args.out}")

    # Summary table
    print("\n=== H3 Ablation Summary ===")
    print(f"{'Variant':<20} {'Link AUC':>10} {'Comp AUC':>10} {'Spearman':>10} {'Epochs':>8}")
    print("-" * 62)
    for name, r in results.items():
        link_auc = _fmt_metric(r["link_pred_test_auc"], ".4f", 10)
        epochs = _fmt_metric(r["epochs_ran"], "d", 8)
        print(f"{name:<20} {link_auc} "
              f"{r['compartment_auc_5fold']:>10.4f} "
              f"{r['spearman_r_best']:>10.4f} "
              f"{epochs}")

    # Interpretation
    print("\n=== Interpretation ===")
    full_c = results["full"]["compartment_auc_5fold"]
    local_c = results["local_only"]["compartment_auc_5fold"]
    longr_c = results["longrange_only"]["compartment_auc_5fold"]
    print(f" Full compartment AUC: {full_c:.4f}")
    print(f" Local-only compartment AUC: {local_c:.4f} (Δ={local_c-full_c:+.4f} vs full)")
    print(f" Long-range only compartment AUC: {longr_c:.4f} (Δ={longr_c-full_c:+.4f} vs full)")
    if longr_c > local_c:
        print(" → Long-range contacts encode MORE compartment structure than local contacts (H3 supported).")
    elif local_c > longr_c:
        print(" → Local contacts dominate compartment signal (H3 not supported).")
    else:
        print(" → Local and long-range contribute equally.")
|
|
|
|
|
|
# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger argument parsing.
if __name__ == "__main__":
    main()
|