#!/usr/bin/env python3
"""
H3 ablation evaluation: compare full vs local-only vs longrange-only VGAE variants.

For each variant reports:
  • Link prediction test AUC/AP (from variant's metrics.json)
  • Compartment classification AUC (5-fold CV logistic regression on embeddings)
  • Spearman r(PC1, best latent dim)
  • Epochs trained (from variant's metrics.json)

The full variant is reused from H1; local- and longrange-only are H3 trainings.

Usage
-----
python experiments/h3_longrange/evaluate_ablation.py \\
    --full_emb results/h1_representation/gm12878_emb.npy \\
    --full_metrics results/h1_representation/metrics.json \\
    --local_dir results/h3_longrange/local_only \\
    --longrange_dir results/h3_longrange/longrange_only \\
    --compartments results/h1_representation/compartments/gm12878_chr1.csv \\
    --out results/h3_longrange/ablation_comparison.json
"""
import argparse
import json
import os

import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler


def cv_auc(X, y, n_splits=5, seed=42):
    """Cross-validated ROC AUC and average precision of a logistic regression.

    Each fold standardises the features with statistics fitted on the
    training split only, so no information leaks from the held-out rows.

    Parameters
    ----------
    X : 2-D array indexed by row; one row per sample.
    y : 1-D array of binary labels aligned with the rows of ``X``.
    n_splits : number of stratified folds.
    seed : random state for both the fold shuffling and the classifier.

    Returns
    -------
    (mean_auc, mean_ap) : tuple of two floats averaged over the folds.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    aucs, aps = [], []
    for tr, te in skf.split(X, y):
        sc = StandardScaler().fit(X[tr])
        lr = LogisticRegression(max_iter=1000, random_state=seed)
        lr.fit(sc.transform(X[tr]), y[tr])
        prob = lr.predict_proba(sc.transform(X[te]))[:, 1]
        aucs.append(roc_auc_score(y[te], prob))
        aps.append(average_precision_score(y[te], prob))
    return float(np.mean(aucs)), float(np.mean(aps))


def _abs_spearman(x, y):
    """Return |Spearman rho| of two 1-D arrays, or NaN when it is undefined."""
    x = np.asarray(x)
    y = np.asarray(y)
    # Rank correlation is undefined for constant or too-short inputs; report
    # NaN directly instead of letting SciPy warn and return NaN.
    if x.size < 2 or y.size < 2 or np.ptp(x) == 0 or np.ptp(y) == 0:
        return float("nan")
    # Tuple unpacking works with both older and newer SciPy result objects,
    # whereas the ``.statistic`` attribute only exists on newer releases.
    rho, _ = spearmanr(x, y)
    return abs(float(rho))


def best_spearman(emb, pc1):
    """Find the embedding column with the strongest |Spearman rho| against pc1.

    Columns whose correlation is undefined (NaN, e.g. a constant latent
    dimension) are ignored so that a single degenerate column cannot turn
    the reported maximum into NaN.

    Parameters
    ----------
    emb : 2-D array; each column is one latent dimension.
    pc1 : 1-D array aligned with the rows of ``emb``.

    Returns
    -------
    (r, dim) : the largest absolute correlation and the index of its column.
        ``(nan, -1)`` when no column yields a defined correlation.
    """
    rs = np.array(
        [_abs_spearman(emb[:, d], pc1) for d in range(emb.shape[1])],
        dtype=float,
    )
    if rs.size == 0 or np.all(np.isnan(rs)):
        return float("nan"), -1
    best = int(np.nanargmax(rs))
    return float(rs[best]), best


def evaluate_one(emb, metrics, pc1, y, mask):
    """Build the result record for one model variant.

    Parameters
    ----------
    emb : 2-D embedding array; rows are selected with ``mask``.
    metrics : dict loaded from the variant's metrics.json. The keys
        ``test_auc``, ``test_ap`` and ``epochs_ran`` are read with ``.get``,
        so missing keys yield ``None`` in the result.
    pc1 : 1-D array aligned with the masked rows.
    y : binary labels aligned with the masked rows.
    mask : boolean row selector applied to ``emb``.

    Returns
    -------
    dict with link-prediction metrics, 5-fold compartment AUC/AP, the best
    absolute Spearman correlation with its dimension, and epochs trained.
    """
    emb_m = emb[mask]
    auc, ap = cv_auc(emb_m, y)
    r, dim = best_spearman(emb_m, pc1)
    return {
        "link_pred_test_auc": metrics.get("test_auc"),
        "link_pred_test_ap": metrics.get("test_ap"),
        "compartment_auc_5fold": auc,
        "compartment_ap_5fold": ap,
        "spearman_r_best": r,
        "spearman_best_dim": dim,
        "epochs_ran": metrics.get("epochs_ran"),
    }


def _load_variant(emb_path, metrics_path):
    """Load an embedding ``.npy`` file and its metrics JSON as ``(emb, metrics)``."""
    emb = np.load(emb_path)
    with open(metrics_path, encoding="utf-8") as f:
        m = json.load(f)
    return emb, m


def _fmt(value, spec):
    """Format ``value`` with ``spec``; return ``"n/a"`` when it is ``None``.

    Metrics are read with ``dict.get`` and may be absent, and formatting
    ``None`` with a numeric spec raises ``TypeError``.
    """
    return "n/a" if value is None else format(value, spec)


def main():
    """Evaluate all three variants, save the results as JSON and print a summary."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--full_emb", required=True,
                        help="Path to full-model emb.npy (e.g. gm12878_emb.npy from H1)")
    parser.add_argument("--full_metrics", required=True,
                        help="Path to full-model metrics.json (from H1)")
    parser.add_argument("--local_dir", required=True,
                        help="Dir containing local-only emb.npy and metrics.json")
    parser.add_argument("--longrange_dir", required=True,
                        help="Dir containing longrange-only emb.npy and metrics.json")
    parser.add_argument("--compartments", required=True,
                        help="GM12878 compartment CSV (with pc1, compartment columns)")
    parser.add_argument("--out", required=True)
    args = parser.parse_args()

    # Keep only bins that carry an A/B label and a PC1 value.
    df = pd.read_csv(args.compartments)
    valid = df["compartment"].isin(["A", "B"]) & df["pc1"].notna()
    mask = valid.values
    pc1 = df.loc[valid, "pc1"].values
    y = (df.loc[valid, "compartment"] == "A").astype(int).values
    print(f"Compartment labels: {mask.sum()} bins with A/B labels "
          f"(A={y.sum()}, B={(~y.astype(bool)).sum()})")

    variants = {
        "full": (args.full_emb, args.full_metrics),
        "local_only": (os.path.join(args.local_dir, "emb.npy"),
                       os.path.join(args.local_dir, "metrics.json")),
        "longrange_only": (os.path.join(args.longrange_dir, "emb.npy"),
                           os.path.join(args.longrange_dir, "metrics.json")),
    }

    results = {}
    for name, (emb_p, met_p) in variants.items():
        emb, mets = _load_variant(emb_p, met_p)
        results[name] = evaluate_one(emb, mets, pc1, y, mask)
        r = results[name]
        print(f"\n=== {name} ===")
        print(f" Link prediction: AUC {_fmt(r['link_pred_test_auc'], '.4f')} "
              f"AP {_fmt(r['link_pred_test_ap'], '.4f')}")
        print(f" Compartment recovery: AUC {r['compartment_auc_5fold']:.4f} "
              f"AP {r['compartment_ap_5fold']:.4f}")
        print(f" Spearman r(PC1, dim {r['spearman_best_dim']}): {r['spearman_r_best']:.4f}")
        print(f" Epochs ran: {r['epochs_ran']}")

    # Create the output directory if needed so the write cannot fail after
    # all the evaluation work has been done.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved → {args.out}")

    # Summary table
    print("\n=== H3 Ablation Summary ===")
    print(f"{'Variant':<20} {'Link AUC':>10} {'Comp AUC':>10} {'Spearman':>10} {'Epochs':>8}")
    print("-" * 62)
    for name, r in results.items():
        print(f"{name:<20} {_fmt(r['link_pred_test_auc'], '.4f'):>10} "
              f"{r['compartment_auc_5fold']:>10.4f} "
              f"{r['spearman_r_best']:>10.4f} "
              f"{_fmt(r['epochs_ran'], 'd'):>8}")

    # Interpretation
    print("\n=== Interpretation ===")
    full_c = results["full"]["compartment_auc_5fold"]
    local_c = results["local_only"]["compartment_auc_5fold"]
    longr_c = results["longrange_only"]["compartment_auc_5fold"]
    print(f" Full compartment AUC: {full_c:.4f}")
    print(f" Local-only compartment AUC: {local_c:.4f} (Δ={local_c-full_c:+.4f} vs full)")
    print(f" Long-range only compartment AUC: {longr_c:.4f} (Δ={longr_c-full_c:+.4f} vs full)")
    if longr_c > local_c:
        print(" → Long-range contacts encode MORE compartment structure than local contacts (H3 supported).")
    elif local_c > longr_c:
        print(" → Local contacts dominate compartment signal (H3 not supported).")
    else:
        print(" → Local and long-range contribute equally.")


if __name__ == "__main__":
    main()