#!/usr/bin/env python3
r"""
H1 evaluation: quantify biological validity of VGAE embeddings.

Runs five comparisons against A/B compartment labels:
  1. VGAE embeddings (GM12878, in-domain, stratified k-fold CV logistic
     regression; k = --n_splits, 5 by default)
  2. PCA baseline (PC1 from O/E Pearson correlation matrix — classical method)
  3. Feature-only (PCA on raw ChIP-seq node features, no graph)
  4. Zero-shot IMR90 VGAE (train LR on GM12878 emb, test on IMR90 emb)
  5. Spearman r(PC1, VGAE latent dims) — alignment without supervision

Usage
-----
python experiments/h1_representation/evaluate.py \
    --gm12878_emb results/h1_representation/gm12878_emb.npy \
    --imr90_emb results/h1_representation/imr90_emb.npy \
    --gm12878_graph data/processed/gm12878/chr1.pt \
    --imr90_graph data/processed/imr90/chr1.pt \
    --comp_gm12878 results/h1_representation/compartments/gm12878_chr1.csv \
    --comp_imr90 results/h1_representation/compartments/imr90_chr1.csv \
    --out results/h1_representation/evaluation.json
"""

import argparse
import json

import numpy as np
import pandas as pd
import torch
from scipy.stats import spearmanr
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def load_compartment_labels(csv_path):
    """Return (pc1, binary_label, valid_mask) aligned to bin index.

    Parameters
    ----------
    csv_path
        CSV file with at least a ``compartment`` and a ``pc1`` column.
        NOTE(review): rows are presumably in genomic-bin order, matching the
        row order of the embeddings — confirm against the producer script.

    Returns
    -------
    pc1 : numpy.ndarray
        ``pc1`` values of the valid rows only.
    y : numpy.ndarray
        Integer labels of the valid rows only: 1 for compartment ``"A"``,
        0 for ``"B"``.
    mask : numpy.ndarray
        Boolean array over *all* rows of the CSV; True where the row has an
        ``"A"``/``"B"`` label and a non-missing ``pc1``.
    """
    df = pd.read_csv(csv_path)
    valid = df["compartment"].isin(["A", "B"]) & df["pc1"].notna()
    mask = valid.to_numpy()
    pc1 = df.loc[valid, "pc1"].to_numpy()
    y = (df.loc[valid, "compartment"] == "A").astype(int).to_numpy()
    return pc1, y, mask


def cv_auc(X, y, n_splits=5, seed=42):
    """Stratified k-fold logistic regression AUC and AP.

    For each fold a ``StandardScaler`` is fitted on the training rows only,
    a logistic regression is trained on the scaled training rows, and ROC AUC
    and average precision are computed on the held-out rows.

    Parameters
    ----------
    X : numpy.ndarray
        Feature matrix, one row per sample.
    y : numpy.ndarray
        Binary labels, one per row of ``X``.
    n_splits : int
        Number of stratified folds.
    seed : int
        Random state for the fold shuffling and the classifier.

    Returns
    -------
    (float, float)
        Mean ROC AUC and mean average precision across folds.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    aucs, aps = [], []
    for tr, te in skf.split(X, y):
        # Fit the scaler on the training fold only so that no statistics
        # from the held-out fold leak into training.
        sc = StandardScaler().fit(X[tr])
        lr = LogisticRegression(max_iter=1000, random_state=seed)
        lr.fit(sc.transform(X[tr]), y[tr])
        prob = lr.predict_proba(sc.transform(X[te]))[:, 1]
        aucs.append(roc_auc_score(y[te], prob))
        aps.append(average_precision_score(y[te], prob))
    return float(np.mean(aucs)), float(np.mean(aps))


def zeroshot_auc(X_train, y_train, X_test, y_test, seed=42):
    """Train on source, evaluate on target (zero-shot transfer).

    The scaler and the logistic regression are fitted on the source data
    only; the target data is transformed with the source scaler.

    Returns
    -------
    (float, float)
        ROC AUC and average precision on the target data.
    """
    sc = StandardScaler().fit(X_train)
    lr = LogisticRegression(max_iter=1000, random_state=seed)
    lr.fit(sc.transform(X_train), y_train)
    prob = lr.predict_proba(sc.transform(X_test))[:, 1]
    return (
        float(roc_auc_score(y_test, prob)),
        float(average_precision_score(y_test, prob)),
    )


def pc1_auc(pc1, y):
    """PC1 as a continuous score — AUC of signed PC1 (A = positive).

    The sign of ``pc1`` is oriented using the labels themselves: if the mean
    PC1 of the A bins (``y == 1``) is not larger than that of the B bins, the
    score is negated before scoring.

    Returns
    -------
    (float, float)
        ROC AUC and average precision of the oriented PC1 against ``y``.
    """
    # Sign convention: A compartment = positive PC1
    a_is_positive = np.mean(pc1[y == 1]) > np.mean(pc1[y == 0])
    signed = pc1 if a_is_positive else -pc1
    return (
        float(roc_auc_score(y, signed)),
        float(average_precision_score(y, signed)),
    )


def best_spearman(emb, pc1):
    """Max |Spearman r| between PC1 and any single latent dimension.

    Parameters
    ----------
    emb : numpy.ndarray
        2-D array, one row per bin and one column per latent dimension.
    pc1 : numpy.ndarray
        PC1 value per row of ``emb``.

    Returns
    -------
    (float, int)
        The largest absolute Spearman correlation and the index of the
        latent dimension that attains it.

    Raises
    ------
    ValueError
        If ``emb`` has no columns.
    """
    if emb.shape[1] == 0:
        raise ValueError("best_spearman: embedding has no latent dimensions")
    rs = np.array(
        [abs(spearmanr(emb[:, d], pc1).statistic) for d in range(emb.shape[1])],
        dtype=float,
    )
    # spearmanr is undefined (NaN) for a constant dimension.  argmax/max
    # would select that NaN and propagate it into the JSON output, so an
    # undefined correlation is counted as "no correlation".
    rs = np.nan_to_num(rs, nan=0.0)
    best_dim = int(np.argmax(rs))
    return float(rs[best_dim]), best_dim


def _select_labelled(name, values, mask):
    """Return the rows of ``values`` selected by the boolean bin ``mask``."""
    if values.shape[0] != mask.shape[0]:
        raise ValueError(
            f"{name}: {values.shape[0]} rows, but the compartment table has "
            f"{mask.shape[0]} bins; inputs must be aligned bin-for-bin"
        )
    return values[mask]


def _node_features(graph):
    """Return ``graph.x`` as a NumPy array."""
    # detach()/cpu() so tensors that require grad or live on an accelerator
    # still convert.
    return graph.x.detach().cpu().numpy()


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Run the five H1 comparisons and write the metrics to ``--out`` as JSON."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--gm12878_emb", required=True)
    parser.add_argument("--imr90_emb", required=True)
    parser.add_argument("--gm12878_graph", required=True)
    parser.add_argument("--imr90_graph", required=True)
    parser.add_argument("--comp_gm12878", required=True)
    parser.add_argument("--comp_imr90", required=True)
    parser.add_argument("--out", required=True)
    parser.add_argument("--n_splits", type=int, default=5)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # ── Load ────────────────────────────────────────────────────────────────
    gm_emb = np.load(args.gm12878_emb)
    ir_emb = np.load(args.imr90_emb)

    # SECURITY: weights_only=False unpickles arbitrary Python objects.  Only
    # point these arguments at graph files you produced yourself.
    gm_graph = torch.load(
        args.gm12878_graph, map_location="cpu", weights_only=False
    )
    ir_graph = torch.load(
        args.imr90_graph, map_location="cpu", weights_only=False
    )

    gm_pc1, gm_y, gm_mask = load_compartment_labels(args.comp_gm12878)
    ir_pc1, ir_y, ir_mask = load_compartment_labels(args.comp_imr90)

    print(f"GM12878 — bins with A/B labels: {gm_mask.sum()} "
          f"(A={gm_y.sum()}, B={(gm_y == 0).sum()})")
    print(f"IMR90 — bins with A/B labels: {ir_mask.sum()} "
          f"(A={ir_y.sum()}, B={(ir_y == 0).sum()})")

    # Masked embeddings (only labelled bins)
    gm_emb_m = _select_labelled("GM12878 embeddings", gm_emb, gm_mask)
    ir_emb_m = _select_labelled("IMR90 embeddings", ir_emb, ir_mask)

    # ChIP-seq node features
    gm_feat = _select_labelled(
        "GM12878 node features", _node_features(gm_graph), gm_mask
    )
    # NOTE: the IMR90 features are not consumed by any comparison below; the
    # selection is kept so a misaligned IMR90 graph still fails early.
    ir_feat = _select_labelled(
        "IMR90 node features", _node_features(ir_graph), ir_mask
    )

    results = {}

    # ── 1. VGAE GM12878 in-domain ───────────────────────────────────────────
    print(f"\n[1] VGAE GM12878 in-domain ({args.n_splits}-fold CV)...")
    auc_gm, ap_gm = cv_auc(
        gm_emb_m, gm_y, n_splits=args.n_splits, seed=args.seed
    )
    results["vgae_gm12878_auc"] = auc_gm
    results["vgae_gm12878_ap"] = ap_gm
    print(f" AUC={auc_gm:.4f} AP={ap_gm:.4f}")

    # ── 2. PCA baseline (PC1 from O/E Pearson — classical method) ───────────
    print("\n[2] PCA baseline (Hi-C PC1 → A/B)...")
    auc_pca, ap_pca = pc1_auc(gm_pc1, gm_y)
    results["pca_baseline_auc"] = auc_pca
    results["pca_baseline_ap"] = ap_pca
    print(f" AUC={auc_pca:.4f} AP={ap_pca:.4f}")

    # ── 3. Feature-only (ChIP-seq PCA, no graph) ────────────────────────────
    print("\n[3] Feature-only (ChIP-seq PCA, no graph)...")
    n_components = min(gm_feat.shape[1], 8)
    feat_pca = PCA(
        n_components=n_components, random_state=args.seed
    ).fit_transform(gm_feat)
    auc_feat, ap_feat = cv_auc(
        feat_pca, gm_y, n_splits=args.n_splits, seed=args.seed
    )
    results["feature_only_auc"] = auc_feat
    results["feature_only_ap"] = ap_feat
    print(f" AUC={auc_feat:.4f} AP={ap_feat:.4f}")

    # ── 4. Zero-shot IMR90 ──────────────────────────────────────────────────
    print("\n[4] Zero-shot IMR90 (train on GM12878 emb, test on IMR90 emb)...")
    auc_zs, ap_zs = zeroshot_auc(
        gm_emb_m, gm_y, ir_emb_m, ir_y, seed=args.seed
    )
    results["vgae_imr90_zeroshot_auc"] = auc_zs
    results["vgae_imr90_zeroshot_ap"] = ap_zs
    print(f" AUC={auc_zs:.4f} AP={ap_zs:.4f}")

    # IMR90 in-domain for reference
    auc_ir, ap_ir = cv_auc(
        ir_emb_m, ir_y, n_splits=args.n_splits, seed=args.seed
    )
    results["vgae_imr90_indomain_auc"] = auc_ir
    results["vgae_imr90_indomain_ap"] = ap_ir
    print(f" IMR90 in-domain (CV): AUC={auc_ir:.4f} AP={ap_ir:.4f}")

    # ── 5. Spearman r(PC1, latent dims) ─────────────────────────────────────
    print("\n[5] Spearman r(PC1, VGAE latent dims)...")
    r_gm, dim_gm = best_spearman(gm_emb_m, gm_pc1)
    r_ir, dim_ir = best_spearman(ir_emb_m, ir_pc1)
    results["spearman_r_gm12878"] = r_gm
    results["spearman_best_dim_gm"] = dim_gm
    results["spearman_r_imr90"] = r_ir
    results["spearman_best_dim_imr90"] = dim_ir
    print(f" GM12878: |r|={r_gm:.4f} (dim {dim_gm})")
    print(f" IMR90: |r|={r_ir:.4f} (dim {dim_ir})")

    # ── Summary ─────────────────────────────────────────────────────────────
    print("\n=== Summary ===")
    print(f" PCA baseline (Hi-C PC1): AUC {results['pca_baseline_auc']:.3f}")
    print(f" Feature-only (ChIP PCA): AUC {results['feature_only_auc']:.3f}")
    print(f" VGAE GM12878 (in-domain): AUC {results['vgae_gm12878_auc']:.3f}")
    print(f" VGAE IMR90 (zero-shot): AUC {results['vgae_imr90_zeroshot_auc']:.3f}")
    print(f" Spearman r GM12878: {results['spearman_r_gm12878']:.3f}")

    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved → {args.out}")


if __name__ == "__main__":
    main()