194 lines
8.8 KiB
Python
194 lines
8.8 KiB
Python
#!/usr/bin/env python3
|
|
"""
H1 evaluation: quantify biological validity of VGAE embeddings.

Runs five comparisons against A/B compartment labels:
  1. VGAE embeddings (GM12878, in-domain, 5-fold CV logistic regression)
  2. PCA baseline (PC1 from O/E Pearson correlation matrix — classical method)
  3. Feature-only (PCA on raw ChIP-seq node features, no graph)
  4. Zero-shot IMR90 VGAE (train LR on GM12878 emb, test on IMR90 emb)
  5. Spearman r(PC1, VGAE latent dims) — alignment without supervision

Usage
-----
    python experiments/h1_representation/evaluate.py \
        --gm12878_emb results/h1_representation/gm12878_emb.npy \
        --imr90_emb results/h1_representation/imr90_emb.npy \
        --gm12878_graph data/processed/gm12878/chr1.pt \
        --imr90_graph data/processed/imr90/chr1.pt \
        --comp_gm12878 results/h1_representation/compartments/gm12878_chr1.csv \
        --comp_imr90 results/h1_representation/compartments/imr90_chr1.csv \
        --out results/h1_representation/evaluation.json
"""
|
|
|
|
import argparse
|
|
import json
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import torch
|
|
from scipy.stats import spearmanr
|
|
from sklearn.decomposition import PCA
|
|
from sklearn.linear_model import LogisticRegression
|
|
from sklearn.metrics import roc_auc_score, average_precision_score
|
|
from sklearn.model_selection import StratifiedKFold
|
|
from sklearn.preprocessing import StandardScaler
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def load_compartment_labels(csv_path):
    """Read a compartment CSV and extract PC1 values plus binary A/B labels.

    The CSV must contain a ``compartment`` column and a ``pc1`` column. A row
    is kept only when its compartment is ``"A"`` or ``"B"`` and its PC1 value
    is not missing.

    Parameters
    ----------
    csv_path
        Anything accepted by ``pandas.read_csv``.

    Returns
    -------
    tuple
        ``(pc1, y, mask)`` where ``pc1`` holds the PC1 values of the kept
        rows, ``y`` is 1 for compartment "A" and 0 for "B" on those same rows,
        and ``mask`` is a boolean array over *all* CSV rows marking the kept
        ones.
    """
    table = pd.read_csv(csv_path)

    has_ab_label = table["compartment"].isin(["A", "B"])
    has_pc1 = table["pc1"].notna()
    keep = has_ab_label & has_pc1

    kept_rows = table.loc[keep]
    pc1_values = kept_rows["pc1"].values
    labels = (kept_rows["compartment"] == "A").astype(int).values
    return pc1_values, labels, keep.values
|
|
|
|
|
|
def cv_auc(X, y, n_splits=5, seed=42):
    """Cross-validated ROC AUC and average precision of a logistic regression.

    The data are split with a shuffled ``StratifiedKFold``. Within each fold a
    ``StandardScaler`` is fitted on the training part only, a logistic
    regression is trained on the scaled training part, and the positive-class
    probabilities on the held-out part are scored.

    Parameters
    ----------
    X : array indexable by integer index arrays (rows = samples)
    y : binary label array, same length as ``X``
    n_splits : number of folds
    seed : random state for both the splitter and the classifier

    Returns
    -------
    tuple of float
        ``(mean ROC AUC, mean average precision)`` over the folds.
    """
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

    per_fold = []
    for train_idx, test_idx in splitter.split(X, y):
        # Scaling statistics come from the training fold only.
        scaler = StandardScaler().fit(X[train_idx])
        clf = LogisticRegression(max_iter=1000, random_state=seed)
        clf.fit(scaler.transform(X[train_idx]), y[train_idx])

        scores = clf.predict_proba(scaler.transform(X[test_idx]))[:, 1]
        per_fold.append(
            (
                roc_auc_score(y[test_idx], scores),
                average_precision_score(y[test_idx], scores),
            )
        )

    fold_aucs, fold_aps = zip(*per_fold)
    return float(np.mean(fold_aucs)), float(np.mean(fold_aps))
|
|
|
|
|
|
def zeroshot_auc(X_train, y_train, X_test, y_test, seed=42):
    """Fit a logistic regression on the source data and score it on the target.

    A ``StandardScaler`` is fitted on ``X_train`` and the same scaling is
    applied to ``X_test``; no target data is used for fitting.

    Returns
    -------
    tuple of float
        ``(ROC AUC, average precision)`` of the positive-class probabilities
        on ``(X_test, y_test)``.
    """
    scaler = StandardScaler()
    source_scaled = scaler.fit_transform(X_train)

    classifier = LogisticRegression(max_iter=1000, random_state=seed)
    classifier.fit(source_scaled, y_train)

    target_scores = classifier.predict_proba(scaler.transform(X_test))[:, 1]
    auc_value = roc_auc_score(y_test, target_scores)
    ap_value = average_precision_score(y_test, target_scores)
    return float(auc_value), float(ap_value)
|
|
|
|
|
|
def pc1_auc(pc1, y):
    """Score PC1 as a continuous predictor of the binary label ``y``.

    PC1 is oriented so that the label-1 group has the larger mean: if the mean
    of ``pc1`` where ``y == 1`` is not greater than the mean where ``y == 0``,
    the sign of PC1 is flipped before scoring.

    Returns
    -------
    tuple of float
        ``(ROC AUC, average precision)`` of the oriented PC1.
    """
    mean_positive = np.mean(pc1[y == 1])
    mean_negative = np.mean(pc1[y == 0])

    if mean_positive > mean_negative:
        oriented = pc1
    else:
        oriented = -pc1

    auc_value = roc_auc_score(y, oriented)
    ap_value = average_precision_score(y, oriented)
    return float(auc_value), float(ap_value)
|
|
|
|
|
|
def best_spearman(emb, pc1):
    """Max |Spearman r| between PC1 and any single latent dimension.

    Spearman correlation is undefined (NaN) for a constant column, and
    ``np.argmax`` / ``np.max`` would let a single NaN win over every real
    correlation. Constant columns are therefore skipped and the best dimension
    is selected with NaN-aware logic.

    Parameters
    ----------
    emb : 2-D array, rows = samples, columns = latent dimensions
    pc1 : 1-D array with one value per row of ``emb``

    Returns
    -------
    tuple
        ``(best_abs_r, best_dim)``. If no dimension has a defined correlation
        (no columns, all columns constant, or all results NaN) the result is
        ``(nan, -1)``.
    """
    n_dims = emb.shape[1]
    abs_r = np.full(n_dims, np.nan, dtype=float)

    for d in range(n_dims):
        column = emb[:, d]
        # A constant (or too short) column has no ranks to correlate; leave
        # its entry as NaN instead of asking scipy for an undefined statistic.
        if column.size < 2 or np.ptp(column) == 0:
            continue
        abs_r[d] = abs(spearmanr(column, pc1).statistic)

    if n_dims == 0 or np.isnan(abs_r).all():
        return float("nan"), -1

    best_dim = int(np.nanargmax(abs_r))
    return float(abs_r[best_dim]), best_dim
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def main():
    """Run all H1 comparisons and write the collected metrics to a JSON file.

    Steps: parse CLI arguments, load embeddings / graphs / compartment labels
    for GM12878 and IMR90, restrict everything to bins with an A/B label, then
    compute

      1. in-domain CV AUC/AP of the GM12878 embeddings,
      2. AUC/AP of Hi-C PC1 itself,
      3. CV AUC/AP of a PCA of the graph node features,
      4. zero-shot AUC/AP on IMR90 (plus IMR90 in-domain CV for reference),
      5. the best |Spearman r| between PC1 and a single latent dimension.

    Results are printed and dumped to ``--out``.

    Raises
    ------
    ValueError
        If an embedding array or a graph's node-feature matrix does not have
        one row per row of the corresponding compartment CSV.
    """
    # Named ``parser`` (not ``ap``) so it cannot be confused with the
    # average-precision values computed below.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--gm12878_emb", required=True)
    parser.add_argument("--imr90_emb", required=True)
    parser.add_argument("--gm12878_graph", required=True)
    parser.add_argument("--imr90_graph", required=True)
    parser.add_argument("--comp_gm12878", required=True)
    parser.add_argument("--comp_imr90", required=True)
    parser.add_argument("--out", required=True)
    parser.add_argument("--n_splits", type=int, default=5)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # ── Load ─────────────────────────────────────────────────────────────────
    gm_emb = np.load(args.gm12878_emb)
    ir_emb = np.load(args.imr90_emb)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only pass graph files that come from a trusted source.
    gm_graph = torch.load(args.gm12878_graph, weights_only=False)
    ir_graph = torch.load(args.imr90_graph, weights_only=False)

    gm_pc1, gm_y, gm_mask = load_compartment_labels(args.comp_gm12878)
    ir_pc1, ir_y, ir_mask = load_compartment_labels(args.comp_imr90)

    print(f"GM12878 — bins with A/B labels: {gm_mask.sum()} "
          f"(A={gm_y.sum()}, B={(~gm_y.astype(bool)).sum()})")
    print(f"IMR90 — bins with A/B labels: {ir_mask.sum()} "
          f"(A={ir_y.sum()}, B={(~ir_y.astype(bool)).sum()})")

    # The boolean masks index rows of the CSVs; embeddings and node features
    # must have exactly one row per CSV row or the masking below is meaningless.
    for label, n_rows, mask in (
        ("GM12878 embeddings", gm_emb.shape[0], gm_mask),
        ("IMR90 embeddings", ir_emb.shape[0], ir_mask),
        ("GM12878 graph node features", gm_graph.x.shape[0], gm_mask),
        ("IMR90 graph node features", ir_graph.x.shape[0], ir_mask),
    ):
        if n_rows != mask.shape[0]:
            raise ValueError(
                f"{label} have {n_rows} rows but the compartment CSV has "
                f"{mask.shape[0]} rows; inputs are not bin-aligned"
            )

    # Masked embeddings (only labelled bins)
    gm_emb_m = gm_emb[gm_mask]
    ir_emb_m = ir_emb[ir_mask]

    # ChIP-seq node features. detach()/cpu() make the conversion work even if
    # the saved tensor requires grad or was stored on an accelerator device.
    gm_feat = gm_graph.x.detach().cpu().numpy()[gm_mask]
    ir_feat = ir_graph.x.detach().cpu().numpy()[ir_mask]  # loaded for parity; not used below

    results = {}

    # ── 1. VGAE GM12878 in-domain ────────────────────────────────────────────
    # The banner reports the fold count actually used (default 5).
    print(f"\n[1] VGAE GM12878 in-domain ({args.n_splits}-fold CV)...")
    auc_gm, ap_gm = cv_auc(gm_emb_m, gm_y, n_splits=args.n_splits, seed=args.seed)
    results["vgae_gm12878_auc"] = auc_gm
    results["vgae_gm12878_ap"] = ap_gm
    print(f" AUC={auc_gm:.4f} AP={ap_gm:.4f}")

    # ── 2. PCA baseline (PC1 from O/E Pearson — classical method) ────────────
    print("\n[2] PCA baseline (Hi-C PC1 → A/B)...")
    auc_pca, ap_pca = pc1_auc(gm_pc1, gm_y)
    results["pca_baseline_auc"] = auc_pca
    results["pca_baseline_ap"] = ap_pca
    print(f" AUC={auc_pca:.4f} AP={ap_pca:.4f}")

    # ── 3. Feature-only (ChIP-seq PCA, no graph) ─────────────────────────────
    print("\n[3] Feature-only (ChIP-seq PCA, no graph)...")
    n_components = min(gm_feat.shape[1], 8)
    feat_pca = PCA(n_components=n_components, random_state=args.seed).fit_transform(gm_feat)
    auc_feat, ap_feat = cv_auc(feat_pca, gm_y, n_splits=args.n_splits, seed=args.seed)
    results["feature_only_auc"] = auc_feat
    results["feature_only_ap"] = ap_feat
    print(f" AUC={auc_feat:.4f} AP={ap_feat:.4f}")

    # ── 4. Zero-shot IMR90 ────────────────────────────────────────────────────
    print("\n[4] Zero-shot IMR90 (train on GM12878 emb, test on IMR90 emb)...")
    auc_zs, ap_zs = zeroshot_auc(gm_emb_m, gm_y, ir_emb_m, ir_y, seed=args.seed)
    results["vgae_imr90_zeroshot_auc"] = auc_zs
    results["vgae_imr90_zeroshot_ap"] = ap_zs
    print(f" AUC={auc_zs:.4f} AP={ap_zs:.4f}")

    # IMR90 in-domain for reference
    auc_ir, ap_ir = cv_auc(ir_emb_m, ir_y, n_splits=args.n_splits, seed=args.seed)
    results["vgae_imr90_indomain_auc"] = auc_ir
    results["vgae_imr90_indomain_ap"] = ap_ir
    print(f" IMR90 in-domain (CV): AUC={auc_ir:.4f} AP={ap_ir:.4f}")

    # ── 5. Spearman r(PC1, latent dims) ──────────────────────────────────────
    print("\n[5] Spearman r(PC1, VGAE latent dims)...")
    r_gm, dim_gm = best_spearman(gm_emb_m, gm_pc1)
    r_ir, dim_ir = best_spearman(ir_emb_m, ir_pc1)
    results["spearman_r_gm12878"] = r_gm
    results["spearman_best_dim_gm"] = dim_gm
    results["spearman_r_imr90"] = r_ir
    results["spearman_best_dim_imr90"] = dim_ir
    print(f" GM12878: |r|={r_gm:.4f} (dim {dim_gm})")
    print(f" IMR90: |r|={r_ir:.4f} (dim {dim_ir})")

    # ── Summary ───────────────────────────────────────────────────────────────
    print("\n=== Summary ===")
    print(f" PCA baseline (Hi-C PC1): AUC {results['pca_baseline_auc']:.3f}")
    print(f" Feature-only (ChIP PCA): AUC {results['feature_only_auc']:.3f}")
    print(f" VGAE GM12878 (in-domain): AUC {results['vgae_gm12878_auc']:.3f}")
    print(f" VGAE IMR90 (zero-shot): AUC {results['vgae_imr90_zeroshot_auc']:.3f}")
    print(f" Spearman r GM12878: {results['spearman_r_gm12878']:.3f}")

    with open(args.out, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved → {args.out}")
|
|
|
|
|
|
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|