Files
chromatin-vgae-hic/experiments/h3_longrange/evaluate_ablation.py

155 lines
6.2 KiB
Python

#!/usr/bin/env python3
"""
H3 ablation evaluation: compare full vs local-only vs longrange-only VGAE variants.
For each variant reports:
• Link prediction test AUC/AP (from variant's metrics.json)
• Compartment classification AUC (5-fold CV logistic regression on embeddings)
• Spearman r(PC1, best latent dim)
• Epochs trained (from the variant's metrics.json)
The full variant is reused from H1; local- and longrange-only are H3 trainings.
Usage
-----
python experiments/h3_longrange/evaluate_ablation.py \\
--full_emb results/h1_representation/gm12878_emb.npy \\
--full_metrics results/h1_representation/metrics.json \\
--local_dir results/h3_longrange/local_only \\
--longrange_dir results/h3_longrange/longrange_only \\
--compartments results/h1_representation/compartments/gm12878_chr1.csv \\
--out results/h3_longrange/ablation_comparison.json
"""
import argparse
import json
import os
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
def cv_auc(X, y, n_splits=5, seed=42):
    """Cross-validated logistic-regression AUC/AP for predicting ``y`` from ``X``.

    Uses shuffled, stratified K-fold splits. Within each fold the features are
    standardised using statistics fitted on the training part only, so no
    information leaks from the held-out bins.

    Args:
        X: Feature matrix, indexed by row (one row per sample).
        y: Binary labels aligned with the rows of ``X``.
        n_splits: Number of folds.
        seed: Random seed for both the fold shuffling and the classifier.

    Returns:
        ``(mean_auc, mean_ap)`` averaged over folds, as Python floats.
    """
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    per_fold = []
    for train_idx, test_idx in splitter.split(X, y):
        scaler = StandardScaler().fit(X[train_idx])
        clf = LogisticRegression(max_iter=1000, random_state=seed)
        clf.fit(scaler.transform(X[train_idx]), y[train_idx])
        scores = clf.predict_proba(scaler.transform(X[test_idx]))[:, 1]
        y_test = y[test_idx]
        per_fold.append(
            (roc_auc_score(y_test, scores), average_precision_score(y_test, scores))
        )
    fold_aucs = [auc for auc, _ in per_fold]
    fold_aps = [ap for _, ap in per_fold]
    return float(np.mean(fold_aucs)), float(np.mean(fold_aps))
def best_spearman(emb, pc1):
    """Find the latent dimension most strongly rank-correlated with PC1.

    Args:
        emb: 2-D array whose columns ``emb[:, d]`` are latent dimensions,
            rows aligned with ``pc1``.
        pc1: 1-D array of PC1 values, one per row of ``emb``.

    Returns:
        ``(abs_r, dim)``: the largest *absolute* Spearman correlation over
        all columns and the index of the column attaining it (first one on
        ties). Columns whose correlation is undefined are ignored. For
        example, ``spearmanr`` returns NaN for a constant column. If no
        column has a defined correlation (or ``emb`` has no columns),
        returns ``(nan, -1)``.
    """
    rs = np.array(
        [abs(spearmanr(emb[:, d], pc1).statistic) for d in range(emb.shape[1])],
        dtype=float,
    )
    # A single dead/constant latent dim yields NaN, which np.max/np.argmax
    # would propagate and report as the "best" result -- skip NaNs instead.
    if rs.size == 0 or np.isnan(rs).all():
        return float("nan"), -1
    best = int(np.nanargmax(rs))
    return float(rs[best]), best
def evaluate_one(emb, metrics, pc1, y, mask):
    """Score one VGAE variant's embedding against compartment labels.

    Args:
        emb: Embedding array with one row per bin; must have exactly one row
            per entry of ``mask``.
        metrics: Dict loaded from the variant's metrics.json. Keys that are
            absent are reported as ``None``.
        pc1: PC1 values for the bins selected by ``mask``.
        y: Binary compartment labels for the bins selected by ``mask``.
        mask: Boolean array selecting the bins to evaluate.

    Returns:
        Dict with the link-prediction test AUC/AP and ``epochs_ran`` copied
        from ``metrics``, the 5-fold compartment classification AUC/AP, and
        the best absolute Spearman correlation with PC1 plus its dimension.

    Raises:
        ValueError: If ``emb`` does not have one row per entry of ``mask``.
    """
    # Fail with a clear message instead of numpy's boolean-index IndexError
    # when the embedding and the compartment table use different binning.
    n_rows, n_bins = np.shape(emb)[0], len(mask)
    if n_rows != n_bins:
        raise ValueError(
            f"embedding has {n_rows} rows but the compartment table has "
            f"{n_bins} bins; they must be aligned bin-for-bin"
        )
    emb_m = emb[mask]
    auc, ap = cv_auc(emb_m, y)
    r, dim = best_spearman(emb_m, pc1)
    return {
        "link_pred_test_auc": metrics.get("test_auc"),
        "link_pred_test_ap": metrics.get("test_ap"),
        "compartment_auc_5fold": auc,
        "compartment_ap_5fold": ap,
        "spearman_r_best": r,
        "spearman_best_dim": dim,
        "epochs_ran": metrics.get("epochs_ran"),
    }
def _load_variant(emb_path, metrics_path):
emb = np.load(emb_path)
with open(metrics_path) as f:
m = json.load(f)
return emb, m
def _fmt_opt(value, spec, width=0):
    """Format an optional metric for display without aborting the report.

    Metrics copied from a variant's metrics.json may be missing (``None``) or
    stored with an unexpected type (e.g. an epoch count saved as a float).

    Args:
        value: The metric value, or ``None`` if absent.
        spec: A ``format`` spec such as ``".4f"`` or ``"d"``.
        width: Minimum field width; the text is right-aligned to it.

    Returns:
        The formatted text. Gives ``"n/a"`` for ``None``, or ``str(value)``
        when ``spec`` does not apply to the value's type.
    """
    if value is None:
        text = "n/a"
    else:
        try:
            text = format(value, spec)
        except (TypeError, ValueError):
            text = str(value)
    return text.rjust(width)


def main():
    """Evaluate the full / local-only / longrange-only variants and compare them.

    Reads the compartment CSV and evaluates each variant's embedding against
    it. Writes the per-variant results as JSON to ``--out``, creating its
    parent directory if needed. Prints a per-variant report, a summary table
    and a one-line interpretation of the compartment AUCs.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("--full_emb", required=True,
                    help="Path to full-model emb.npy (e.g. gm12878_emb.npy from H1)")
    ap.add_argument("--full_metrics", required=True,
                    help="Path to full-model metrics.json (from H1)")
    ap.add_argument("--local_dir", required=True,
                    help="Dir containing local-only emb.npy and metrics.json")
    ap.add_argument("--longrange_dir", required=True,
                    help="Dir containing longrange-only emb.npy and metrics.json")
    ap.add_argument("--compartments", required=True,
                    help="GM12878 compartment CSV (with pc1, compartment columns)")
    ap.add_argument("--out", required=True)
    args = ap.parse_args()

    df = pd.read_csv(args.compartments)
    # Score only bins with a definite A/B call and a defined PC1 value.
    valid = df["compartment"].isin(["A", "B"]) & df["pc1"].notna()
    mask = valid.values
    pc1 = df.loc[valid, "pc1"].values
    y = (df.loc[valid, "compartment"] == "A").astype(int).values
    print(f"Compartment labels: {mask.sum()} bins with A/B labels "
          f"(A={y.sum()}, B={(~y.astype(bool)).sum()})")

    variants = {
        "full": (args.full_emb,
                 args.full_metrics),
        "local_only": (os.path.join(args.local_dir, "emb.npy"),
                       os.path.join(args.local_dir, "metrics.json")),
        "longrange_only": (os.path.join(args.longrange_dir, "emb.npy"),
                           os.path.join(args.longrange_dir, "metrics.json")),
    }

    results = {}
    for name, (emb_p, met_p) in variants.items():
        emb, mets = _load_variant(emb_p, met_p)
        results[name] = evaluate_one(emb, mets, pc1, y, mask)
        r = results[name]
        print(f"\n=== {name} ===")
        # Link-prediction values are copied from metrics.json and may be None
        # when a variant's file lacks them; don't crash before saving results.
        print(f" Link prediction: AUC {_fmt_opt(r['link_pred_test_auc'], '.4f')} "
              f"AP {_fmt_opt(r['link_pred_test_ap'], '.4f')}")
        print(f" Compartment recovery: AUC {r['compartment_auc_5fold']:.4f} "
              f"AP {r['compartment_ap_5fold']:.4f}")
        print(f" Spearman r(PC1, dim {r['spearman_best_dim']}): {r['spearman_r_best']:.4f}")
        print(f" Epochs ran: {r['epochs_ran']}")

    # Make sure the destination directory exists before writing.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved → {args.out}")

    # Summary table
    print("\n=== H3 Ablation Summary ===")
    print(f"{'Variant':<20} {'Link AUC':>10} {'Comp AUC':>10} {'Spearman':>10} {'Epochs':>8}")
    print("-" * 62)
    for name, r in results.items():
        print(f"{name:<20} {_fmt_opt(r['link_pred_test_auc'], '.4f', 10)} "
              f"{_fmt_opt(r['compartment_auc_5fold'], '.4f', 10)} "
              f"{_fmt_opt(r['spearman_r_best'], '.4f', 10)} "
              f"{_fmt_opt(r['epochs_ran'], 'd', 8)}")

    # Interpretation
    print("\n=== Interpretation ===")
    full_c = results["full"]["compartment_auc_5fold"]
    local_c = results["local_only"]["compartment_auc_5fold"]
    longr_c = results["longrange_only"]["compartment_auc_5fold"]
    print(f" Full compartment AUC: {full_c:.4f}")
    print(f" Local-only compartment AUC: {local_c:.4f} (Δ={local_c-full_c:+.4f} vs full)")
    print(f" Long-range only compartment AUC: {longr_c:.4f} (Δ={longr_c-full_c:+.4f} vs full)")
    if longr_c > local_c:
        print(" → Long-range contacts encode MORE compartment structure than local contacts (H3 supported).")
    elif local_c > longr_c:
        print(" → Local contacts dominate compartment signal (H3 not supported).")
    else:
        print(" → Local and long-range contribute equally.")


if __name__ == "__main__":
    main()