Files
chromatin-vgae-hic/experiments/h1_representation/evaluate.py

194 lines
8.8 KiB
Python

#!/usr/bin/env python3
"""
H1 evaluation: quantify biological validity of VGAE embeddings.
Runs five comparisons against A/B compartment labels:
1. VGAE embeddings (GM12878, in-domain, 5-fold CV logistic regression)
2. PCA baseline (PC1 from O/E Pearson correlation matrix — classical method)
3. Feature-only (PCA on raw ChIP-seq node features, no graph)
4. Zero-shot IMR90 VGAE (train LR on GM12878 emb, test on IMR90 emb), plus an
IMR90 in-domain CV run reported for reference
5. Spearman r(PC1, VGAE latent dims) — alignment without supervision
Usage
-----
python experiments/h1_representation/evaluate.py \
--gm12878_emb results/h1_representation/gm12878_emb.npy \
--imr90_emb results/h1_representation/imr90_emb.npy \
--gm12878_graph data/processed/gm12878/chr1.pt \
--imr90_graph data/processed/imr90/chr1.pt \
--comp_gm12878 results/h1_representation/compartments/gm12878_chr1.csv \
--comp_imr90 results/h1_representation/compartments/imr90_chr1.csv \
--out results/h1_representation/evaluation.json
"""
import argparse
import json
import numpy as np
import pandas as pd
import torch
from scipy.stats import spearmanr
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def load_compartment_labels(csv_path):
    """Read a compartment CSV and extract the usable A/B calls.

    The CSV must provide a ``compartment`` column and a ``pc1`` column.
    A row is kept when its compartment is exactly ``"A"`` or ``"B"`` and its
    ``pc1`` value is not missing.

    Returns
    -------
    pc1 : ndarray
        ``pc1`` values of the kept rows only.
    y : ndarray of int
        1 where the kept row is compartment ``"A"``, 0 where it is ``"B"``.
    mask : ndarray of bool
        One entry per CSV row (full length), True for kept rows. Callers use
        it to select the matching rows of other per-bin arrays.
    """
    table = pd.read_csv(csv_path)

    has_ab_call = table["compartment"].isin(["A", "B"])
    has_pc1 = table["pc1"].notna()
    keep = has_ab_call & has_pc1

    kept_rows = table.loc[keep]
    pc1_values = kept_rows["pc1"].values
    labels = (kept_rows["compartment"] == "A").astype(int).values
    return pc1_values, labels, keep.values
def cv_auc(X, y, n_splits=5, seed=42):
    """Cross-validated ROC AUC and average precision of a logistic regression.

    Folds come from a shuffled ``StratifiedKFold`` seeded with ``seed``.
    Within each fold the features are standardised with statistics fitted on
    the training part only, a logistic regression is fitted on the training
    part, and the held-out part is scored with the second column of
    ``predict_proba``.

    Returns
    -------
    (mean_auc, mean_ap) : tuple of float
        Fold-averaged ROC AUC and average precision.
    """

    def _score_fold(train_idx, test_idx):
        # Scaler is fitted on the training rows only, then reused for the test rows.
        scaler = StandardScaler().fit(X[train_idx])
        clf = LogisticRegression(max_iter=1000, random_state=seed)
        clf.fit(scaler.transform(X[train_idx]), y[train_idx])
        held_out_prob = clf.predict_proba(scaler.transform(X[test_idx]))[:, 1]
        fold_auc = roc_auc_score(y[test_idx], held_out_prob)
        fold_ap = average_precision_score(y[test_idx], held_out_prob)
        return fold_auc, fold_ap

    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    fold_aucs = []
    fold_aps = []
    for train_idx, test_idx in splitter.split(X, y):
        fold_auc, fold_ap = _score_fold(train_idx, test_idx)
        fold_aucs.append(fold_auc)
        fold_aps.append(fold_ap)

    mean_auc = float(np.mean(fold_aucs))
    mean_ap = float(np.mean(fold_aps))
    return mean_auc, mean_ap
def zeroshot_auc(X_train, y_train, X_test, y_test, seed=42):
    """Fit on the source set, score the target set without any re-fitting.

    Both the standardisation statistics and the logistic regression are
    fitted on ``X_train``/``y_train`` only; ``X_test`` is transformed with the
    source statistics and scored with the second column of ``predict_proba``.

    Returns
    -------
    (auc, ap) : tuple of float
        ROC AUC and average precision on the target set.
    """
    source_scaler = StandardScaler()
    source_scaler.fit(X_train)

    clf = LogisticRegression(max_iter=1000, random_state=seed)
    clf.fit(source_scaler.transform(X_train), y_train)

    # Target features are scaled with the *source* mean/std.
    target_prob = clf.predict_proba(source_scaler.transform(X_test))[:, 1]

    auc = roc_auc_score(y_test, target_prob)
    ap = average_precision_score(y_test, target_prob)
    return float(auc), float(ap)
def pc1_auc(pc1, y):
    """Score PC1 directly as a ranking of A (label 1) over B (label 0).

    PC1 is negated when the mean PC1 of the label-1 bins is not greater than
    the mean of the label-0 bins, so that higher scores correspond to label 1.

    NOTE(review): the orientation is chosen from the very labels that are
    then used for scoring — confirm this is the intended baseline protocol.

    Returns
    -------
    (auc, ap) : tuple of float
        ROC AUC and average precision of the oriented PC1 score.
    """
    mean_pc1_a = np.mean(pc1[y == 1])
    mean_pc1_b = np.mean(pc1[y == 0])

    # Sign convention: A compartment = positive PC1
    if mean_pc1_a > mean_pc1_b:
        oriented = pc1
    else:
        oriented = -pc1

    auc = roc_auc_score(y, oriented)
    ap = average_precision_score(y, oriented)
    return float(auc), float(ap)
def best_spearman(emb, pc1):
    """Max |Spearman r| between PC1 and any single latent dimension.

    Parameters
    ----------
    emb : array-like, 2-D
        One row per bin, one column per latent dimension.
    pc1 : array-like, 1-D
        PC1 value per bin, same length as ``emb`` has rows.

    Returns
    -------
    (best_abs_r, best_dim) : (float, int)
        The largest absolute Spearman correlation and the column index that
        attains it. Dimensions whose correlation is undefined (constant
        column, or NaN result) are ignored. If no dimension has a defined
        correlation the result is ``(nan, -1)``.
    """
    emb = np.asarray(emb)
    pc1 = np.asarray(pc1)
    n_bins, n_dims = emb.shape

    abs_rho = np.full(n_dims, np.nan)
    # Rank correlation is undefined with fewer than two bins or a constant PC1.
    if n_bins >= 2 and np.ptp(pc1) != 0:
        for d in range(n_dims):
            column = emb[:, d]
            # A constant latent dimension has no defined rank correlation;
            # skipping it keeps one dead dimension from turning the overall
            # maximum into NaN.
            if np.ptp(column) == 0:
                continue
            # Tuple unpacking works for both old and new SciPy result types.
            rho, _ = spearmanr(column, pc1)
            abs_rho[d] = abs(rho)

    if np.isnan(abs_rho).all():
        return float("nan"), -1

    best_dim = int(np.nanargmax(abs_rho))
    return float(abs_rho[best_dim]), best_dim
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _check_bin_alignment(cell, n_labels, n_emb, n_nodes):
    """Raise ValueError unless labels, embeddings and graph nodes cover the same bins."""
    if not (n_labels == n_emb == n_nodes):
        raise ValueError(
            f"{cell}: bin counts disagree (compartment CSV rows={n_labels}, "
            f"embedding rows={n_emb}, graph node-feature rows={n_nodes}); "
            "all three must be indexed by the same bins."
        )


def main():
    """Run all H1 comparisons and write the metrics to ``--out`` as JSON.

    Steps: load embeddings, graphs and compartment labels for both cell
    lines; restrict everything to bins with an A/B label; compute
    (1) in-domain CV on GM12878 embeddings, (2) the Hi-C PC1 baseline,
    (3) a ChIP-seq-feature PCA baseline, (4) zero-shot transfer to IMR90 plus
    an IMR90 in-domain reference, and (5) the best single-dimension Spearman
    correlation with PC1; print a summary and dump the results dictionary.

    Raises
    ------
    ValueError
        If the compartment CSV, the embedding matrix and the graph node
        features of a cell line do not have the same number of rows.
    """
    import os  # local import: only needed to create the output directory

    # Named ``parser`` so it is not rebound by the ``ap`` (average precision)
    # variables used further down.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--gm12878_emb", required=True)
    parser.add_argument("--imr90_emb", required=True)
    parser.add_argument("--gm12878_graph", required=True)
    parser.add_argument("--imr90_graph", required=True)
    parser.add_argument("--comp_gm12878", required=True)
    parser.add_argument("--comp_imr90", required=True)
    parser.add_argument("--out", required=True)
    parser.add_argument("--n_splits", type=int, default=5)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # ── Load ─────────────────────────────────────────────────────────────────
    gm_emb = np.load(args.gm12878_emb)
    ir_emb = np.load(args.imr90_emb)
    # NOTE(security): weights_only=False unpickles arbitrary Python objects —
    # only point these arguments at graph files you trust.
    # map_location="cpu" lets graphs saved on a GPU load on a CPU-only machine.
    gm_graph = torch.load(args.gm12878_graph, map_location="cpu", weights_only=False)
    ir_graph = torch.load(args.imr90_graph, map_location="cpu", weights_only=False)
    gm_pc1, gm_y, gm_mask = load_compartment_labels(args.comp_gm12878)
    ir_pc1, ir_y, ir_mask = load_compartment_labels(args.comp_imr90)

    # The boolean masks below only make sense if every per-bin array has the
    # same number of rows; fail early with a readable message otherwise.
    _check_bin_alignment("GM12878", len(gm_mask), gm_emb.shape[0], gm_graph.x.shape[0])
    _check_bin_alignment("IMR90", len(ir_mask), ir_emb.shape[0], ir_graph.x.shape[0])

    print(f"GM12878 — bins with A/B labels: {gm_mask.sum()} "
          f"(A={gm_y.sum()}, B={(~gm_y.astype(bool)).sum()})")
    print(f"IMR90 — bins with A/B labels: {ir_mask.sum()} "
          f"(A={ir_y.sum()}, B={(~ir_y.astype(bool)).sum()})")

    # Masked embeddings (only labelled bins)
    gm_emb_m = gm_emb[gm_mask]
    ir_emb_m = ir_emb[ir_mask]

    # ChIP-seq node features; detach()/cpu() so the conversion also works for
    # tensors that track gradients or were saved from a GPU.
    gm_feat = gm_graph.x.detach().cpu().numpy()[gm_mask]

    results = {}

    # ── 1. VGAE GM12878 in-domain ────────────────────────────────────────────
    print(f"\n[1] VGAE GM12878 in-domain ({args.n_splits}-fold CV)...")
    auc, ap = cv_auc(gm_emb_m, gm_y, n_splits=args.n_splits, seed=args.seed)
    results["vgae_gm12878_auc"] = auc
    results["vgae_gm12878_ap"] = ap
    print(f" AUC={auc:.4f} AP={ap:.4f}")

    # ── 2. PCA baseline (PC1 from O/E Pearson — classical method) ────────────
    print("\n[2] PCA baseline (Hi-C PC1 → A/B)...")
    auc_pca, ap_pca = pc1_auc(gm_pc1, gm_y)
    results["pca_baseline_auc"] = auc_pca
    results["pca_baseline_ap"] = ap_pca
    print(f" AUC={auc_pca:.4f} AP={ap_pca:.4f}")

    # ── 3. Feature-only (ChIP-seq PCA, no graph) ─────────────────────────────
    print("\n[3] Feature-only (ChIP-seq PCA, no graph)...")
    # PCA cannot extract more components than min(n_samples, n_features).
    n_components = min(gm_feat.shape[0], gm_feat.shape[1], 8)
    feat_pca = PCA(n_components=n_components, random_state=args.seed).fit_transform(gm_feat)
    auc_feat, ap_feat = cv_auc(feat_pca, gm_y, n_splits=args.n_splits, seed=args.seed)
    results["feature_only_auc"] = auc_feat
    results["feature_only_ap"] = ap_feat
    print(f" AUC={auc_feat:.4f} AP={ap_feat:.4f}")

    # ── 4. Zero-shot IMR90 ────────────────────────────────────────────────────
    print("\n[4] Zero-shot IMR90 (train on GM12878 emb, test on IMR90 emb)...")
    auc_zs, ap_zs = zeroshot_auc(gm_emb_m, gm_y, ir_emb_m, ir_y, seed=args.seed)
    results["vgae_imr90_zeroshot_auc"] = auc_zs
    results["vgae_imr90_zeroshot_ap"] = ap_zs
    print(f" AUC={auc_zs:.4f} AP={ap_zs:.4f}")

    # IMR90 in-domain for reference
    auc_ir, ap_ir = cv_auc(ir_emb_m, ir_y, n_splits=args.n_splits, seed=args.seed)
    results["vgae_imr90_indomain_auc"] = auc_ir
    results["vgae_imr90_indomain_ap"] = ap_ir
    print(f" IMR90 in-domain (CV): AUC={auc_ir:.4f} AP={ap_ir:.4f}")

    # ── 5. Spearman r(PC1, latent dims) ──────────────────────────────────────
    print("\n[5] Spearman r(PC1, VGAE latent dims)...")
    r_gm, dim_gm = best_spearman(gm_emb_m, gm_pc1)
    r_ir, dim_ir = best_spearman(ir_emb_m, ir_pc1)
    results["spearman_r_gm12878"] = r_gm
    results["spearman_best_dim_gm"] = dim_gm
    results["spearman_r_imr90"] = r_ir
    results["spearman_best_dim_imr90"] = dim_ir
    print(f" GM12878: |r|={r_gm:.4f} (dim {dim_gm})")
    print(f" IMR90: |r|={r_ir:.4f} (dim {dim_ir})")

    # ── Summary ───────────────────────────────────────────────────────────────
    print("\n=== Summary ===")
    print(f" PCA baseline (Hi-C PC1): AUC {results['pca_baseline_auc']:.3f}")
    print(f" Feature-only (ChIP PCA): AUC {results['feature_only_auc']:.3f}")
    print(f" VGAE GM12878 (in-domain): AUC {results['vgae_gm12878_auc']:.3f}")
    print(f" VGAE IMR90 (zero-shot): AUC {results['vgae_imr90_zeroshot_auc']:.3f}")
    print(f" Spearman r GM12878: {results['spearman_r_gm12878']:.3f}")

    # Create the destination directory so a missing folder cannot discard the
    # results after all the computation has already been done.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved → {args.out}")
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()