v1.0.0: VGAE applied to GM12878 vs IMR90 chr21 Hi-C at 25kb

Full reproducible pipeline: .mcool + ChIP-seq bigwigs → latent
  embeddings → A/B compartment calls → cross-cell comparison.

  Key results (chr21, 25 kb, latent dim=32):
  - Test AUC=0.777, AP=0.759 (converged epoch 31/300)
  - GM12878 A/B silhouette (cosine) = 0.775
  - IMR90 zero-shot silhouette = 0.443
  - A-compartment bins stable across cell types (mean cosine Δ=0.042)
  - B-compartment bins shift substantially (mean cosine Δ=0.451)
  - 101 B→A and 70 A→B compartment switches GM12878→IMR90
This commit is contained in:
2026-05-15 01:53:04 +02:00
parent 6c91af655d
commit acadbd780c
27 changed files with 6764 additions and 201 deletions

48
.gitignore vendored
View File

@@ -1,30 +1,30 @@
# Raw sequencing and contact data (large files; download via run_pipeline.sh)
data/raw/
# Python
__pycache__/
*.pyc
*.py[cod]
*.pyo
.pytest_cache/
*.egg-info/
dist/
build/
# Conda / mamba envs
.env/
*.yml.lock
# Environments
.env
*.env
.venv/
# Data
*.hic
*.mcool
*.cool
*.bam
*.bw
*.bigwig
*.bed
*.pairs*
*.pt
*.npy
*.csv
*.png
# Editors
.vscode/
.idea/
*.swp
*~
# Jupyter and logs
*.ipynb_checkpoints/
*.log
# Jupyter
.ipynb_checkpoints/
*.ipynb
# OS
.DS_Store
# Results / temp
results/
data/
Thumbs.db

60
CITATION.cff Normal file
View File

@@ -0,0 +1,60 @@
cff-version: 1.2.0
message: >-
If you use this software in your research, please cite it using the
following metadata.
type: software
title: >-
chromatin-gnn: Variational Graph Autoencoder for learning latent
representations of chromatin topology from Hi-C data
authors:
- family-names: Okada
given-names: Toru
alias: ToruOkadaOi
# orcid: "https://orcid.org/XXXX-XXXX-XXXX-XXXX" # add your ORCID
version: "1.0.0"
date-released: "2024-01-01" # update to actual release date
doi: "10.5281/zenodo.XXXXXXX" # replace with actual Zenodo DOI after deposit
repository-code: "https://github.com/ToruOkadaOi/chromatin-gnn"
url: "https://github.com/ToruOkadaOi/chromatin-gnn"
license: MIT
abstract: >-
A Variational Graph Autoencoder (VGAE) applied to Hi-C chromatin contact
data to learn unsupervised latent representations of chromatin topology.
Genomic bins are modelled as graph nodes with ChIP-seq features (CTCF,
H3K27me3); normalised contact frequencies define weighted edges.
The model is trained on GM12878 lymphoblastoid cells and evaluated on
both link-prediction (AUROC/AP) and the biological interpretability of
the latent space against known A/B compartments.
keywords:
- chromatin
- Hi-C
- graph neural network
- variational autoencoder
- VGAE
- A/B compartments
- topologically associating domains
- TAD
- epigenomics
- 3D genome organisation
references:
- type: article
title: >-
A 3D Map of the Human Genome at Kilobase Resolution Reveals
Principles of Chromatin Looping
authors:
- family-names: Rao
given-names: "Suhas S. P."
- family-names: Huntley
given-names: "Miriam H."
year: 2014
journal: Cell
doi: 10.1016/j.cell.2014.11.021
- type: article
title: "Variational Graph Auto-Encoders"
authors:
- family-names: Kipf
given-names: "Thomas N."
- family-names: Welling
given-names: Max
year: 2016
url: "https://arxiv.org/abs/1611.07308"

309
README.md
View File

@@ -1,19 +1,302 @@
# Graph learning for genome architecture
# chromatin-gnn
### workflow
**Variational Graph Autoencoder for learning latent representations of chromatin topology from Hi-C data**
[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.XXXXXXX.svg)](https://doi.org/10.5281/zenodo.XXXXXXX)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
---
## Overview
The three-dimensional organisation of chromatin in the nucleus is not random. Chromosomes fold into compartments, topologically associating domains (TADs), and loop structures that correlate strongly with gene regulation. Hi-C sequencing captures these contacts genome-wide, but the resulting data are high-dimensional and require principled dimensionality reduction to extract biologically interpretable structure.
This repository applies a **Variational Graph Autoencoder (VGAE)** to Hi-C contact data to learn a compact, continuous latent representation of chromatin topology. Genomic bins are treated as graph nodes; normalised contact frequencies define weighted edges; ChIP-seq tracks for CTCF and H3K27me3 supply node features. The model is trained end-to-end on a link-prediction objective and evaluated for its ability to recover known biological structure — A/B compartments — in an entirely unsupervised manner.
---
## Scientific question
> Can a VGAE learn biologically meaningful latent representations of chromatin topology — capturing A/B compartments and cell-type-specific reorganisation — from Hi-C contact data alone, in an unsupervised manner?
---
## Architecture
```
Node features (2D: CTCF, H3K27me3)
BatchNorm
GCNConv(64) ← shared message-passing layer
ReLU + Dropout(0.2)
/ \
GCNConv(32) GCNConv(32)
μ log σ
\ /
Reparameterisation
z ∈ ℝ³² (node embeddings)
Inner-product decoder
(link prediction objective: binary cross-entropy + KL divergence)
```
The encoder is a two-layer Graph Convolutional Network (Kipf & Welling 2016, 2017) with a BatchNorm input layer. The decoder is the standard dot-product decoder used in the original VGAE paper. Training uses a link-prediction objective: the model is asked to distinguish real Hi-C contacts from randomly sampled non-contacts.
---
## Dataset
All data are from the GRCh38/hg38 reference genome, chromosome 21 at 25 kb resolution.
| File | Cell line | Type | Source | Accession |
|------|-----------|------|--------|-----------|
| GM12878.mcool | GM12878 (lymphoblastoid) | Hi-C contact matrix | 4DN Data Portal | 4DNFIRUMEC32 |
| IMR90.mcool | IMR-90 (lung fibroblast) | Hi-C contact matrix | 4DN Data Portal | 4DNFIABB3FHQ |
| GM12878_CTCF.bw | GM12878 | CTCF ChIP-seq (FC/control) | ENCODE | ENCFF741BAQ (exp. ENCSR000AKB) |
| GM12878_H3K27me3.bw | GM12878 | H3K27me3 ChIP-seq (FC/control) | ENCODE | ENCFF736CNQ (exp. ENCSR000AKD) |
| IMR90_CTCF.bw | IMR-90 | CTCF ChIP-seq (FC/control) | ENCODE | ENCFF770DUD (exp. ENCSR000EFI) |
| IMR90_H3K27me3.bw | IMR-90 | H3K27me3 ChIP-seq (FC/control) | ENCODE | ENCFF158HZL (exp. ENCSR431UUY) |
**Graph statistics:**
| Cell line | Bins (chr21, 25 kb) | Edges (contacts) | Node features |
|-----------|---------------------|------------------|---------------|
| GM12878 | 1,869 | 87,557 | 2 (CTCF, H3K27me3) |
| IMR90 | 1,869 | 136,121 | 2 (CTCF, H3K27me3) |
IMR90 has ~55% more intra-chromosomal contacts than GM12878 at chr21, suggesting a more compact or contact-rich chromatin organisation in this fibroblast cell line.
---
## Installation
```bash
# Build graph (contact matrix & bigwig needed)
python scripts/build_graph.py --mcool x.mcool --chrom chrx --res x
--bigwigs xCTCFx.bw xH3K27me3x.bw --out data/chrx_xconditionx.pt
conda create -n chromatin_gnn python=3.10 -y
conda activate chromatin_gnn
# Train model
python scripts/train_vgae.py --graph data/chrx_xconditionx.pt --epochs 100 --outdir results
# CPU-only PyTorch (replace URL for GPU builds)
pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cpu
# Encode treatment graph
python scripts/encode_graph.py --model results/model.pt --graph data/chrx_xtreatmentx.pt --out results/emb_xtreatmentx.npy
# Compare embeddings
python scripts/compare_embeddings_general.py --emb1 results/emb.npy --emb2 results/emb_xtreatmentx.npy \
--label1 xControlx --label2 xTreatmentx --prefix results/chrx
# All other dependencies
pip install torch-geometric==2.5.3 cooler==0.9.3 pyBigWig pandas \
"numpy>=1.24,<2.0" scikit-learn matplotlib umap-learn scipy seaborn tqdm
```
> **Note:** `cooler==0.9.3` requires `numpy<2.0`. The env.yml captures the exact versions used for this release.
---
## Workflow
```bash
# Full end-to-end run (downloads bigwigs automatically; .mcool files must be present in data/raw/)
bash run_pipeline.sh
# Or run individual steps:
# 1. Build contact graph
python scripts/build_graph.py \
--mcool data/raw/GM12878.mcool \
--chrom chr21 --res 25000 \
--bigwigs data/raw/GM12878_CTCF.bw data/raw/GM12878_H3K27me3.bw \
--out data/processed/GM12878_chr21.pt
# 2. Compute A/B compartments
python scripts/compute_compartments.py \
--mcool data/raw/GM12878.mcool --chrom chr21 --res 25000 \
--bigwig_orient data/raw/GM12878_CTCF.bw \
--out results/GM12878/compartments_chr21.csv
# 3. Train VGAE
python scripts/train_vgae.py \
--graph data/processed/GM12878_chr21.pt \
--epochs 300 --patience 20 --hidden 64 --latent 32 \
--outdir results/GM12878
# 4. Encode a second cell line with the trained model
python scripts/encode_graph.py \
--model results/GM12878/model.pt \
--graph data/processed/IMR90_chr21.pt \
--out results/IMR90/emb.npy
# 5. UMAP visualisation
python scripts/visualize_embeddings.py \
--emb results/GM12878/emb.npy results/IMR90/emb.npy \
--labels GM12878 IMR90 \
--compartments results/GM12878/compartments_chr21.csv \
results/IMR90/compartments_chr21.csv \
--prefix results/figures/umap
# 6. Per-bin embedding comparison
python scripts/compare_embeddings.py \
--emb1 results/GM12878/emb.npy --emb2 results/IMR90/emb.npy \
--label1 GM12878 --label2 IMR90 \
--prefix results/figures/chr21
```
---
## Results
### Training (GM12878, chr21, 25 kb)
| Metric | Value |
|--------|-------|
| Epochs to convergence | 31 / 300 (early stopping, patience=20) |
| Validation AUC (link prediction) | 0.774 |
| Test AUC | 0.777 |
| Test AP | 0.759 |
| Latent dimensionality | 32 |
The model converged rapidly, suggesting that the graph structure of chr21 at 25 kb is learnable with a shallow two-layer GCN.
---
### A/B compartment separation in the latent space
The UMAP of GM12878 node embeddings coloured by A/B compartment shows strong, clean separation of the two compartment types without the model ever receiving compartment labels during training.
| Cell line | Silhouette score (A/B, cosine) | A bins | B bins | Masked (N) |
|-----------|-------------------------------|--------|--------|------------|
| GM12878 (training) | **0.775** | 602 | 683 | 584 |
| IMR90 (zero-shot) | 0.443 | 614 | 709 | 546 |
The GM12878 silhouette of **0.775** indicates that the VGAE has learned a latent space in which A and B compartments are nearly linearly separable — a strong signal given that compartment identity was never provided as a training label.
For IMR90, encoded zero-shot with the GM12878-trained model, the silhouette drops to **0.443**. This is expected: the model's BatchNorm statistics were fit to GM12878, and IMR90's chromatin organisation partially diverges.
**Figures:**
| Figure | Description |
|--------|-------------|
| `results/figures/umap_GM12878_compartment.png` | GM12878 UMAP coloured by A/B compartment |
| `results/figures/umap_GM12878_position.png` | GM12878 UMAP coloured by genomic position |
| `results/figures/umap_IMR90_compartment.png` | IMR90 UMAP coloured by A/B compartment |
| `results/figures/umap_joint.png` | Joint UMAP of both cell lines |
| `results/figures/chr21_delta.png` | Per-bin cosine distance track (GM12878 vs IMR90) |
---
### Cell-type comparison: GM12878 vs IMR90
The IMR90 graph was encoded with the GM12878-trained model (zero-shot transfer). Per-bin cosine distances between the two embedding matrices reveal which genomic loci undergo the largest chromatin reorganisation between cell types.
**Per-bin cosine distance summary:**
| Statistic | Value |
|-----------|-------|
| Mean | 0.245 |
| Median | 0.028 |
| Bins with distance < 0.1 (stable) | 968 / 1,869 (52%) |
| Bins with distance > 0.5 (high shift) | 293 / 1,869 (16%) |
**Mean cosine distance by GM12878 compartment:**
| Compartment | Mean distance | Median distance |
|-------------|--------------|-----------------|
| A (active) | 0.042 | 0.001 |
| B (repressive) | 0.451 | 0.352 |
| N (masked) | 0.213 | 0.000 |
**Key finding:** A-compartment bins are nearly invariant between the two cell types (mean Δ = 0.042), while B-compartment bins shift substantially (mean Δ = 0.451). This is consistent with the known biology: constitutively active chromatin domains tend to be conserved across cell types, while heterochromatic B-compartment organisation is more cell-type-specific.
**Compartment switches (GM12878 → IMR90):**
| GM12878 | IMR90 | Bins | Interpretation |
|---------|-------|------|----------------|
| A | A | 493 | Stable active |
| B | B | 581 | Stable repressive |
| B → A | A | 101 | Loci that open in IMR90 |
| A → B | B | 70 | Loci that close in IMR90 |
101 loci switch from B (repressive in GM12878) to A (active in IMR90), versus 70 in the reverse direction. This asymmetry suggests that IMR90 fibroblasts activate more lineage-specific loci on chr21 than GM12878 lymphoblastoid cells.
---
## Biological interpretation
The results demonstrate that a VGAE trained on Hi-C data **recovers A/B compartment structure without supervision** (silhouette = 0.775). The latent space organises chr21 bins according to their chromatin state, not just their linear genomic position — the UMAP coloured by genomic position shows a broadly continuous gradient, while the compartment-coloured UMAP shows discrete clusters.
The zero-shot application to IMR90 captures the partial conservation of compartment organisation across cell types. The B-compartment instability revealed by the cosine distance analysis is consistent with the literature: heterochromatin rewiring is a known driver of cell-type identity (Lieberman-Aiden et al. 2009; Dixon et al. 2015).
Notably, the model achieves this with only two node features (CTCF and H3K27me3 signal), demonstrating that a modest set of epigenomic marks, combined with contact topology, is sufficient to encode the major axis of chromatin organisation.
---
## Limitations
1. **Single chromosome, single resolution.** Results are for chr21 at 25 kb only. Chr21 is acrocentric with a large masked pericentromeric region (584 / 1,869 bins masked in GM12878), which may reduce statistical power compared to gene-rich autosomes.
2. **Shallow encoder.** The two-layer GCN has a local receptive field (2-hop neighbourhood). Long-range chromatin interactions spanning multiple TADs are not directly encoded. Deeper networks or attention-based architectures may capture these better.
3. **Link-prediction objective ≠ compartment recovery.** The model is optimised to predict contacts, not compartments. The strong silhouette score is emergent, not guaranteed. The objective could be supplemented with biologically-informed losses.
4. **Zero-shot transfer with fixed BatchNorm.** Encoding IMR90 with GM12878 BatchNorm statistics means the model sees IMR90 features in GM12878's normalisation frame. A domain-adaptation approach (e.g., re-fitting BatchNorm on IMR90 with frozen GCN weights) would give a fairer comparison.
5. **Compartment calling is approximate.** The O/E → Pearson correlation → PC1 pipeline is sensitive to the choice of orientation signal (CTCF here). Bins with low coverage are masked and assigned no compartment label, which affects the silhouette calculation.
6. **No TAD-level evaluation.** TAD boundary detection would require a graph-level or boundary-aware metric. The current evaluation is node-level only.
7. **No statistical significance testing.** The compartment switch counts (101 B→A, 70 A→B) have not been tested for significance against a null model (e.g., random permutation of compartment labels).
---
## Future work
- Apply to all autosomes and compare genome-wide compartment recovery.
- Add a TAD-boundary evaluation metric (e.g., insulation score correlation with latent space gradients).
- Fine-tune on IMR90 (transfer learning) to improve the IMR90 silhouette score.
- Add cohesin depletion or auxin-inducible degron (AID) perturbation data as a controlled condition comparison.
- Replace the inner-product decoder with a distance-aware decoder that incorporates linear genomic distance.
- Benchmark against PCA/UMAP of the raw contact matrix and against other graph-based methods (GraphSAGE, GAT).
- Extend node features to include additional histone marks (H3K4me3, H3K27ac, H3K9me3) to test whether richer epigenomic context improves compartment recovery.
---
## Scripts
| Script | Purpose |
|--------|---------|
| `scripts/build_graph.py` | Convert .mcool + bigWigs → PyG Data object |
| `scripts/train_vgae.py` | Train VGAE with link-prediction objective + early stopping |
| `scripts/encode_graph.py` | Encode a new graph with a trained model |
| `scripts/compute_compartments.py` | O/E matrix → Pearson PCA → A/B compartment calls |
| `scripts/visualize_embeddings.py` | UMAP + compartment visualisation + silhouette |
| `scripts/compare_embeddings.py` | Per-bin cosine/L2/L1 distance between two embeddings |
| `scripts/model.py` | Shared Encoder class (imported by train and encode scripts) |
---
## Citation
If you use this code or data in your research, please cite:
```bibtex
@software{okada_chromatin_gnn_2024,
author = {Okada, Toru},
title = {{chromatin-gnn: Variational Graph Autoencoder for learning
latent representations of chromatin topology from Hi-C data}},
year = {2024},
doi = {10.5281/zenodo.XXXXXXX},
url = {https://github.com/ToruOkadaOi/chromatin-gnn},
version = {1.0.0}
}
```
See also `CITATION.cff` for CFF-format metadata.
**Key references:**
- Kipf, T. N. & Welling, M. (2016). Variational Graph Auto-Encoders. *arXiv:1611.07308*
- Kipf, T. N. & Welling, M. (2017). Semi-Supervised Classification with Graph Convolutional Networks. *ICLR 2017*
- Rao, S. S. P. et al. (2014). A 3D Map of the Human Genome at Kilobase Resolution Reveals Principles of Chromatin Looping. *Cell*, 159(7), 16651680. https://doi.org/10.1016/j.cell.2014.11.021
- Lieberman-Aiden, E. et al. (2009). Comprehensive mapping of long-range interactions reveals folding principles of the human genome. *Science*, 326(5950), 289293.
---
## License
MIT — see [LICENSE](LICENSE).

Binary file not shown.

Binary file not shown.

36
env.yml
View File

@@ -1,21 +1,27 @@
name: chromatin_gnn_aman
name: chromatin_gnn
# Installation (pip-based, conda used only for Python):
# conda create -n chromatin_gnn python=3.10 -y
# conda activate chromatin_gnn
# pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cpu
# pip install -r requirements.txt
#
# For GPU support replace the torch line with:
# pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu118
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- python=3.10
- pytorch
- torchvision
- torchaudio
- cooler
- pybigwig
- pandas
- numpy
- scikit-learn
- matplotlib
- umap-learn
- pip
- pip:
- torch-geometric
- torch==2.1.2 # CPU build; see GPU note above
- torch-geometric==2.5.3
- cooler==0.9.3
- pyBigWig==0.3.25
- pandas==2.1.4
- "numpy>=1.24,<2.0" # cooler 0.9.3 requires numpy<2.0
- scikit-learn==1.4.2
- matplotlib==3.8.4
- umap-learn==0.5.12
- scipy==1.12.0
- seaborn==0.13.2
- tqdm==4.66.2

File diff suppressed because it is too large Load Diff

BIN
results/GM12878/emb.npy Normal file

Binary file not shown.

View File

@@ -0,0 +1,13 @@
{
"val_auc": 0.774052272942853,
"test_auc": 0.7767560244060842,
"test_ap": 0.7585872967832136,
"epochs_ran": 31,
"epochs_max": 300,
"patience": 20,
"hidden": 64,
"latent": 32,
"dropout": 0.2,
"lr": 0.001,
"seed": 42
}

BIN
results/GM12878/model.pt Normal file

Binary file not shown.

File diff suppressed because it is too large Load Diff

BIN
results/IMR90/emb.npy Normal file

Binary file not shown.

File diff suppressed because it is too large Load Diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 219 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 189 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 238 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 204 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 267 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 304 KiB

View File

@@ -0,0 +1,3 @@
label,n_bins,latent_dim,mean_embedding_norm,std_embedding_values,silhouette_AB_cosine
GM12878,1869,32,0.6679670810699463,0.1371903121471405,0.7748898267745972
IMR90,1869,32,0.7111130356788635,0.14772675931453705,0.4431541860103607
1 label n_bins latent_dim mean_embedding_norm std_embedding_values silhouette_AB_cosine
2 GM12878 1869 32 0.6679670810699463 0.1371903121471405 0.7748898267745972
3 IMR90 1869 32 0.7111130356788635 0.14772675931453705 0.4431541860103607

162
run_pipeline.sh Normal file
View File

@@ -0,0 +1,162 @@
#!/usr/bin/env bash
#
# run_pipeline.sh end-to-end pipeline for chromatin-gnn
#
# Prerequisites
# -------------
# 1. Create and activate the environment:
# conda create -n chromatin_gnn python=3.10 -y
# conda activate chromatin_gnn
# pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cpu
# pip install torch-geometric==2.5.3 cooler==0.9.3 pyBigWig pandas \
# "numpy>=1.24,<2.0" scikit-learn matplotlib umap-learn scipy seaborn tqdm
#
# 2. Download raw data into data/raw/ (see README.md § Dataset for URLs):
# GM12878.mcool 4DN accession 4DNFIRUMEC32
# IMR90.mcool 4DN accession 4DNFIABB3FHQ
# GM12878_CTCF.bw ENCODE ENCFF741BAQ (experiment ENCSR000AKB)
# GM12878_H3K27me3.bw ENCODE ENCFF736CNQ (experiment ENCSR000AKD)
# IMR90_CTCF.bw ENCODE ENCFF770DUD (experiment ENCSR000EFI)
# IMR90_H3K27me3.bw ENCODE ENCFF158HZL (experiment ENCSR431UUY)
#
# Usage
# -----
# bash run_pipeline.sh [--chrom chr21] [--res 25000] [--epochs 300]
set -euo pipefail
# ========== Configuration ==========
CHROM="${CHROM:-chr21}"
RES="${RES:-25000}"
EPOCHS="${EPOCHS:-300}"
PATIENCE="${PATIENCE:-20}"
HIDDEN="${HIDDEN:-64}"
LATENT="${LATENT:-32}"
SEED="${SEED:-42}"
REPO="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
SCRIPTS="$REPO/scripts"
DATA="$REPO/data"
RESULTS="$REPO/results"
# ========== Directories ==========
mkdir -p "$DATA/raw" "$DATA/processed" \
"$RESULTS/GM12878" "$RESULTS/IMR90" "$RESULTS/figures"
# ========== Download ENCODE bigWig tracks ==========
echo "=== Downloading ENCODE bigWig tracks ==="
for entry in \
"GM12878_CTCF.bw|https://www.encodeproject.org/files/ENCFF741BAQ/@@download/ENCFF741BAQ.bigWig" \
"GM12878_H3K27me3.bw|https://www.encodeproject.org/files/ENCFF736CNQ/@@download/ENCFF736CNQ.bigWig" \
"IMR90_CTCF.bw|https://www.encodeproject.org/files/ENCFF770DUD/@@download/ENCFF770DUD.bigWig" \
"IMR90_H3K27me3.bw|https://www.encodeproject.org/files/ENCFF158HZL/@@download/ENCFF158HZL.bigWig"
do
fname="${entry%%|*}"
url="${entry##*|}"
out="$DATA/raw/$fname"
if [ -f "$out" ]; then
echo " $fname already present, skipping"
else
echo " Downloading $fname ..."
wget -q --show-progress -O "$out" "$url"
fi
done
# .mcool files must be downloaded manually from 4DN (requires free account):
# GM12878: https://data.4dnucleome.org/files-processed/4DNFIRUMEC32/@@download/4DNFIRUMEC32.mcool
# IMR90: https://data.4dnucleome.org/files-processed/4DNFIABB3FHQ/@@download/4DNFIABB3FHQ.mcool
for f in GM12878.mcool IMR90.mcool; do
if [ ! -f "$DATA/raw/$f" ]; then
echo "ERROR: $DATA/raw/$f not found. Download from 4DN (see README) and retry." >&2
exit 1
fi
done
# ========== Step 1: Build contact graphs ==========
echo ""
echo "=== Step 1: Building chromatin contact graphs ==="
for CELL in GM12878 IMR90; do
OUT="$DATA/processed/${CELL}_${CHROM}.pt"
if [ -f "$OUT" ]; then
echo " ${CELL} graph already exists, skipping"
else
python "$SCRIPTS/build_graph.py" \
--mcool "$DATA/raw/${CELL}.mcool" \
--chrom "$CHROM" --res "$RES" \
--bigwigs "$DATA/raw/${CELL}_CTCF.bw" "$DATA/raw/${CELL}_H3K27me3.bw" \
--out "$OUT"
fi
done
# ========== Step 2: Compute A/B compartments ==========
echo ""
echo "=== Step 2: Computing A/B compartments (PC1 of O/E Pearson correlation) ==="
for CELL in GM12878 IMR90; do
OUT="$RESULTS/${CELL}/compartments_${CHROM}.csv"
if [ -f "$OUT" ]; then
echo " ${CELL} compartments already exist, skipping"
else
python "$SCRIPTS/compute_compartments.py" \
--mcool "$DATA/raw/${CELL}.mcool" \
--chrom "$CHROM" --res "$RES" \
--bigwig_orient "$DATA/raw/${CELL}_CTCF.bw" \
--out "$OUT"
fi
done
# ========== Step 3: Train VGAE on GM12878 ==========
echo ""
echo "=== Step 3: Training VGAE on GM12878 ==="
if [ -f "$RESULTS/GM12878/model.pt" ]; then
echo " Trained model already exists, skipping"
else
python "$SCRIPTS/train_vgae.py" \
--graph "$DATA/processed/GM12878_${CHROM}.pt" \
--epochs "$EPOCHS" --patience "$PATIENCE" \
--hidden "$HIDDEN" --latent "$LATENT" \
--seed "$SEED" \
--outdir "$RESULTS/GM12878"
fi
# ========== Step 4: Encode IMR90 with GM12878 model ==========
echo ""
echo "=== Step 4: Encoding IMR90 graph with trained GM12878 model ==="
if [ -f "$RESULTS/IMR90/emb.npy" ]; then
echo " IMR90 embeddings already exist, skipping"
else
python "$SCRIPTS/encode_graph.py" \
--model "$RESULTS/GM12878/model.pt" \
--graph "$DATA/processed/IMR90_${CHROM}.pt" \
--out "$RESULTS/IMR90/emb.npy"
fi
# ========== Step 5: Visualise embeddings ==========
echo ""
echo "=== Step 5: Generating UMAP visualisations ==="
python "$SCRIPTS/visualize_embeddings.py" \
--emb "$RESULTS/GM12878/emb.npy" "$RESULTS/IMR90/emb.npy" \
--labels GM12878 IMR90 \
--compartments \
"$RESULTS/GM12878/compartments_${CHROM}.csv" \
"$RESULTS/IMR90/compartments_${CHROM}.csv" \
--prefix "$RESULTS/figures/umap" \
--seed "$SEED"
# ========== Step 6: Compare embeddings ==========
echo ""
echo "=== Step 6: Comparing GM12878 vs IMR90 embeddings ==="
python "$SCRIPTS/compare_embeddings.py" \
--emb1 "$RESULTS/GM12878/emb.npy" \
--emb2 "$RESULTS/IMR90/emb.npy" \
--label1 GM12878 --label2 IMR90 \
--prefix "$RESULTS/figures/${CHROM}"
# ========== Summary ==========
echo ""
echo "=== Pipeline complete ==="
echo "Outputs:"
echo " Model + embeddings : $RESULTS/GM12878/"
echo " Figures : $RESULTS/figures/"
echo " Metrics : $RESULTS/GM12878/metrics.json"
echo ""
cat "$RESULTS/GM12878/metrics.json"

View File

@@ -50,7 +50,8 @@ def main():
if emb1.shape != emb2.shape:
raise ValueError(f"Shape mismatch: {emb1.shape} vs {emb2.shape}")
os.makedirs(os.path.dirname(args.prefix), exist_ok=True)
prefix_dir = os.path.dirname(os.path.abspath(args.prefix))
os.makedirs(prefix_dir, exist_ok=True)
n_bins, n_dim = emb1.shape
print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims")

View File

@@ -0,0 +1,170 @@
#!/usr/bin/env python3
"""
Compute A/B chromatin compartments from a Hi-C .mcool file.
Algorithm
---------
1. Load ICE-balanced contact matrix for the target chromosome.
2. Distance-normalise to O/E (divide each diagonal by its mean contact frequency).
3. Compute Pearson correlation matrix of the O/E rows.
4. PCA of the correlation matrix; PC1 distinguishes A from B compartments.
5. Orient the PC1 sign using --bigwig_orient (e.g. CTCF):
positive PC1 → high signal in that track.
With CTCF: positive PC1 = CTCF-enriched = A compartment (active).
With H3K27me3: pass --flip_orient so positive PC1 = B compartment (repressive).
Output
------
CSV with columns: chrom, start, end, pc1, compartment (A / B / N for masked bins).
"""
import argparse
import os
import sys
import cooler
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
sys.path.insert(0, os.path.dirname(__file__))
def _bin_bigwig(bw_path: str, chrom: str, bins) -> np.ndarray:
"""Average bigWig signal over a list of (start, end) genomic bins."""
import pyBigWig
bw = pyBigWig.open(bw_path)
chrom_len = bw.chroms().get(chrom, 0)
vals = []
for s, e in bins:
s, e = max(0, int(s)), min(chrom_len, int(e))
if s >= e:
vals.append(0.0)
continue
v = bw.stats(chrom, s, e, type="mean")[0]
vals.append(0.0 if v is None or np.isnan(v) else float(v))
bw.close()
return np.array(vals)
def _observed_over_expected(matrix: np.ndarray) -> np.ndarray:
"""Distance-normalise a symmetric contact matrix (O/E transform)."""
n = matrix.shape[0]
oe = np.zeros((n, n), dtype=float)
for d in range(n):
idx = np.arange(n - d)
diag = matrix[idx, idx + d].astype(float)
positive = diag[diag > 0]
if positive.size == 0:
continue
mean_d = positive.mean()
norm_diag = np.where((np.isnan(diag)) | (diag == 0), 0.0, diag / mean_d)
oe[idx, idx + d] = norm_diag
if d > 0:
oe[idx + d, idx] = norm_diag
return oe
def compute_compartments(
mcool_path: str,
chrom: str,
res: int,
orient_signal=None,
flip_orient: bool = False,
) -> pd.DataFrame:
"""
Return a DataFrame (chrom, start, end, pc1, compartment).
Parameters
----------
orient_signal : array-like, optional
Per-bin 1-D signal used to fix the sign of PC1.
Pass CTCF signal for positive-PC1 = A convention.
flip_orient : bool
If True, high orient_signal maps to negative PC1 (use with H3K27me3).
"""
c = cooler.Cooler(f"{mcool_path}::resolutions/{res}")
bins_df = c.bins().fetch(chrom).reset_index(drop=True)
matrix = c.matrix(balance=True).fetch(chrom).astype(float)
bad_bins = np.isnan(matrix).all(axis=0) | (matrix.sum(axis=0) == 0)
np.nan_to_num(matrix, nan=0.0, copy=False)
oe = _observed_over_expected(matrix)
good = ~bad_bins
oe_good = oe[np.ix_(good, good)]
# Zero rows produce NaN in corrcoef; add tiny noise to avoid singularity
row_norms = np.linalg.norm(oe_good, axis=1)
oe_good[row_norms == 0] += 1e-9
corr = np.corrcoef(oe_good)
np.nan_to_num(corr, nan=0.0, copy=False)
pca = PCA(n_components=3, random_state=42)
pcs = pca.fit_transform(corr)
pc1_good = pcs[:, 0]
pc1 = np.full(len(bins_df), np.nan)
pc1[good] = pc1_good
if orient_signal is not None:
sig = np.asarray(orient_signal, dtype=float)
sig_good = sig[good]
valid = ~np.isnan(sig_good) & ~np.isnan(pc1_good)
if valid.sum() > 10:
r = np.corrcoef(pc1_good[valid], sig_good[valid])[0, 1]
# By default: positive orient_signal → positive PC1.
# flip_orient reverses this (e.g. H3K27me3 → positive PC1 = B).
if (r < 0 and not flip_orient) or (r > 0 and flip_orient):
pc1 = -pc1
bins_df["pc1"] = pc1
bins_df["compartment"] = np.where(
np.isnan(bins_df["pc1"]), "N",
np.where(bins_df["pc1"] > 0, "A", "B"),
)
return bins_df[["chrom", "start", "end", "pc1", "compartment"]]
def main():
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
p.add_argument("--mcool", required=True, help="Path to .mcool file")
p.add_argument("--chrom", required=True, help="Chromosome (e.g. chr21)")
p.add_argument("--res", type=int, default=25000, help="Resolution in bp (default: 25000)")
p.add_argument("--bigwig_orient",
help="bigWig track for PC1 sign orientation (recommended: CTCF)")
p.add_argument("--flip_orient", action="store_true",
help="Flip orientation: high signal → negative PC1 (use with H3K27me3)")
p.add_argument("--out", required=True, help="Output CSV path")
args = p.parse_args()
orient_signal = None
if args.bigwig_orient:
c = cooler.Cooler(f"{args.mcool}::resolutions/{args.res}")
bins_df = c.bins().fetch(args.chrom).reset_index(drop=True)
coords = list(zip(bins_df["start"].values, bins_df["end"].values))
orient_signal = _bin_bigwig(args.bigwig_orient, args.chrom, coords)
print(f"Loaded orientation signal: {os.path.basename(args.bigwig_orient)}")
df = compute_compartments(
args.mcool, args.chrom, args.res,
orient_signal=orient_signal,
flip_orient=args.flip_orient,
)
os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
df.to_csv(args.out, index=False)
n_a = (df["compartment"] == "A").sum()
n_b = (df["compartment"] == "B").sum()
n_nan = (df["compartment"] == "N").sum()
print(f"Saved → {args.out}")
print(f" A: {n_a} B: {n_b} N/masked: {n_nan} bins")
if __name__ == "__main__":
main()

View File

@@ -1,60 +1,88 @@
#!/usr/bin/env python3
"""
Encode a new graph using a trained VGAE model.
Automatically infers hidden/latent dimensions from saved weights.
Encode a chromatin contact graph using a trained VGAE model.
Dimensions (in_dim, hidden, latent) are inferred automatically from the saved
state_dict. The BatchNorm running statistics from training are restored, so the
same normalisation is applied to held-out cell lines without a separate scaler.
Usage
-----
python scripts/encode_graph.py \\
--model results/GM12878/model.pt \\
--graph data/processed/IMR90_chr21.pt \\
--out results/IMR90/emb.npy
"""
import argparse, torch, numpy as np
from torch_geometric.nn import GCNConv, VGAE
import argparse
import os
import sys
# Reuse your Encoder definition directly here for clarity
class Encoder(torch.nn.Module):
def __init__(self, in_dim, hidden, latent, dropout=0.2):
super().__init__()
self.gc1 = GCNConv(in_dim, hidden)
self.gc_mu = GCNConv(hidden, latent)
self.gc_log = GCNConv(hidden, latent)
self.dropout = dropout
import numpy as np
import torch
from torch_geometric.nn.models import VGAE
def forward(self, x, edge_index):
import torch.nn.functional as F
h = self.gc1(x, edge_index)
h = F.relu(h)
h = F.dropout(h, p=self.dropout, training=self.training)
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
sys.path.insert(0, os.path.dirname(__file__))
from model import Encoder
def _infer_dims(state_dict: dict) -> tuple:
"""Infer (in_dim, hidden, latent) from a VGAE state_dict."""
keys = list(state_dict.keys())
def _first_weight(substr):
for k in keys:
if (substr in k
and "weight" in k
and "running" not in k
and "num_batches" not in k):
return state_dict[k].shape
raise KeyError(f"No weight key containing '{substr}' in state_dict. "
f"Available keys: {keys}")
gc1_shape = _first_weight("gc1") # shape [hidden, in_dim]
gc_mu_shape = _first_weight("gc_mu") # shape [latent, hidden]
hidden = gc1_shape[0]
latent = gc_mu_shape[0]
# in_dim from BatchNorm weight (shape [in_dim])
for k in keys:
if "norm" in k and k.endswith("weight") and "running" not in k:
in_dim = state_dict[k].shape[0]
break
else:
in_dim = gc1_shape[1] # fallback: second dim of gc1 weight
return in_dim, hidden, latent
def main():
p = argparse.ArgumentParser()
p.add_argument("--model", required=True)
p.add_argument("--graph", required=True)
p.add_argument("--out", required=True)
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
p.add_argument("--model", required=True,
help="Path to model.pt saved by train_vgae.py")
p.add_argument("--graph", required=True,
help="Path to Data .pt file from build_graph.py")
p.add_argument("--out", required=True,
help="Output .npy path for node embeddings")
args = p.parse_args()
# ---- Load data and model state ----
data = torch.load(args.graph)
model_state = torch.load(args.model, map_location="cpu")
data = torch.load(args.graph, weights_only=False)
state_dict = torch.load(args.model, map_location="cpu", weights_only=False)
# ---- Infer dimensions dynamically ----
in_dim = data.x.size(1)
# detect hidden and latent dimensions safely
keys = list(model_state.keys())
gc1_weight = [k for k in keys if "gc1" in k and "weight" in k][0]
gc_mu_weight = [k for k in keys if "gc_mu" in k and "weight" in k][0]
hidden = model_state[gc1_weight].shape[0]
latent = model_state[gc_mu_weight].shape[0]
print(f"Inferred dims: in={in_dim}, hidden={hidden}, latent={latent}")
in_dim, hidden, latent = _infer_dims(state_dict)
print(f"Inferred: in_dim={in_dim} hidden={hidden} latent={latent}")
enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
model = VGAE(enc)
model.load_state_dict(model_state)
model.load_state_dict(state_dict)
model.eval()
# ---- Encode ----
with torch.no_grad():
z = model.encode(data.x.float(), data.edge_index)
os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
np.save(args.out, z.cpu().numpy())
print(f"Saved embeddings → {args.out} shape={z.shape}")

32
scripts/model.py Normal file
View File

@@ -0,0 +1,32 @@
#!/usr/bin/env python3
"""Shared VGAE encoder. Imported by train_vgae.py and encode_graph.py."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class Encoder(nn.Module):
"""Two-layer GCN encoder for VGAE with input BatchNorm.
Architecture: BatchNorm → GCN(hidden) → ReLU → Dropout → GCN_mu / GCN_logstd
The BatchNorm layer normalises raw ChIP-seq signals and its running statistics
are saved in model.pt, so encode_graph.py applies identical normalisation to
held-out cell lines without a separate scaler file.
"""
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2):
super().__init__()
self.norm = nn.BatchNorm1d(in_dim)
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
self.dropout = dropout
def forward(self, x, edge_index):
x = self.norm(x)
h = F.relu(self.gc1(x, edge_index))
h = F.dropout(h, p=self.dropout, training=self.training)
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)

View File

@@ -1,72 +1,66 @@
#!/usr/bin/env python3
"""
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
---
Inputs:
- A PyTorch Geometric Data object saved with torch.save()
- from build_graph.py
---
Outputs (under results/):
- model.pt : trained VGAE state_dict
- emb.npy : node embeddings (mean; shape [num_nodes, latent_dim])
- metrics.json : train/val/test AUC/AP summary
Inputs
------
PyTorch Geometric Data object saved by build_graph.py.
Outputs (under --outdir)
------------------------
model.pt trained VGAE state_dict (includes BatchNorm running statistics)
emb.npy node embeddings — mu vector, shape [num_nodes, latent_dim]
metrics.json val/test AUC & AP plus all hyperparameters
"""
import os, json, argparse, numpy as np, torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import argparse
import json
import os
import sys
import numpy as np
import torch
from torch_geometric.nn.models import VGAE
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import to_undirected, remove_self_loops
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score, average_precision_score
from torch_geometric.utils import (
negative_sampling,
remove_self_loops,
to_undirected,
)
from sklearn.metrics import average_precision_score, roc_auc_score
class Encoder(torch.nn.Module):
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2):
super().__init__()
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
self.dropout = dropout
def forward(self, x, edge_index):
h = self.gc1(x, edge_index)
h = F.relu(h)
h = F.dropout(h, p=self.dropout, training=self.training)
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
sys.path.insert(0, os.path.dirname(__file__))
from model import Encoder
@torch.no_grad()
def eval_linkpred(model, data_like, z):
"""Compute AUROC/AP using provided positive/negative edges."""
pos = data_like.pos_edge_index
neg = data_like.neg_edge_index
# model.test returns (auc, ap) but relies on torchmetrics in some versions;
# compute explicitly for stability:
def sigmoid(x): return 1 / (1 + torch.exp(-x))
def _eval_linkpred(z, pos_edges, neg_edges):
"""Return (AUROC, AP) for link prediction."""
def _sigmoid(x):
return 1.0 / (1.0 + torch.exp(-x))
# Inner product decoder scores
def scores(edges):
def _score(edges):
src, dst = edges
s = (z[src] * z[dst]).sum(dim=1)
return sigmoid(s).cpu().numpy()
return _sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy()
y_true = np.concatenate([np.ones(pos.size(1)), np.zeros(neg.size(1))])
y_pred = np.concatenate([scores(pos), scores(neg)])
auc = roc_auc_score(y_true, y_pred)
ap = average_precision_score(y_true, y_pred)
return auc, ap
y_true = np.concatenate([np.ones(pos_edges.size(1)),
np.zeros(neg_edges.size(1))])
y_pred = np.concatenate([_score(pos_edges), _score(neg_edges)])
return roc_auc_score(y_true, y_pred), average_precision_score(y_true, y_pred)
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--graph", required=True, help="Path to Data .pt file")
ap.add_argument("--epochs", type=int, default=100)
ap = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
ap.add_argument("--graph", required=True,
help="Path to Data .pt file from build_graph.py")
ap.add_argument("--epochs", type=int, default=300)
ap.add_argument("--patience", type=int, default=20,
help="Early-stopping patience (val-AUC epochs without improvement)")
ap.add_argument("--lr", type=float, default=1e-3)
ap.add_argument("--hidden", type=int, default=128)
ap.add_argument("--latent", type=int, default=64)
ap.add_argument("--hidden", type=int, default=64)
ap.add_argument("--latent", type=int, default=32)
ap.add_argument("--dropout", type=float, default=0.2)
ap.add_argument("--seed", type=int, default=42)
ap.add_argument("--outdir", default="results")
@@ -76,84 +70,94 @@ def main():
np.random.seed(args.seed)
os.makedirs(args.outdir, exist_ok=True)
# Load graph
data = torch.load(args.graph)
# Coalesce/clean edges
# ---- Load and clean graph ----
data = torch.load(args.graph, weights_only=False)
ei, _ = remove_self_loops(data.edge_index)
data.edge_index = to_undirected(ei, num_nodes=data.num_nodes)
x = data.x.float()
print(f"Graph: {data.num_nodes} nodes "
f"{data.edge_index.shape[1]} edges "
f"{x.shape[1]} node features")
# Split edges for link prediction
# ---- Edge splits for link-prediction evaluation ----
splitter = RandomLinkSplit(
num_val=0.1,
num_test=0.1,
num_val=0.1, num_test=0.1,
is_undirected=True,
add_negative_train_samples=False,
split_labels=False,
)
train_data, val_data, test_data = splitter(data)
# Positive edges are just the edges in each split
train_data.pos_edge_index = train_data.edge_index
val_data.pos_edge_index = val_data.edge_index
test_data.pos_edge_index = test_data.edge_index
# Generate negative edges for validation and test manually
for subset in [val_data, test_data]:
subset.neg_edge_index = negative_sampling(
edge_index=subset.edge_index,
for split in (val_data, test_data):
split.pos_edge_index = split.edge_index
split.neg_edge_index = negative_sampling(
edge_index=split.edge_index,
num_nodes=data.num_nodes,
num_neg_samples=subset.edge_index.size(1),
method='sparse'
num_neg_samples=split.edge_index.size(1),
method="sparse",
)
# Model
enc = Encoder(in_dim=x.size(1), hidden=args.hidden, latent=args.latent, dropout=args.dropout)
# ---- Model ----
enc = Encoder(in_dim=x.size(1), hidden=args.hidden,
latent=args.latent, dropout=args.dropout)
model = VGAE(enc)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# Training loop
# ---- Training loop with early stopping ----
best_val_auc = -1.0
best_state = None
no_improve = 0
epochs_ran = 0
for epoch in range(1, args.epochs + 1):
model.train()
optimizer.zero_grad()
# Encode using remaining training edges
z = model.encode(x, train_data.edge_index)
# Reconstruction loss on positive training edges (negatives sampled inside)
loss_recon = model.recon_loss(z, train_data.pos_edge_index)
# KL divergence regularizer
loss_kl = (1.0 / data.num_nodes) * model.kl_loss()
loss = loss_recon + loss_kl
loss = (model.recon_loss(z, train_data.pos_edge_index)
+ (1.0 / data.num_nodes) * model.kl_loss())
loss.backward()
optimizer.step()
# Validation
model.eval()
with torch.no_grad():
z_full = model.encode(x, data.edge_index) # use full graph for eval embeddings
val_auc, val_ap = eval_linkpred(model, val_data, z_full)
z_full = model.encode(x, data.edge_index)
val_auc, val_ap = _eval_linkpred(
z_full, val_data.pos_edge_index, val_data.neg_edge_index
)
if val_auc > best_val_auc:
best_val_auc = val_auc
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
best_state = {k: v.cpu().clone()
for k, v in model.state_dict().items()}
no_improve = 0
else:
no_improve += 1
epochs_ran = epoch
if epoch % 10 == 0 or epoch == 1:
print(f"[{epoch:03d}/{args.epochs}] loss={loss.item():.4f} | val AUC={val_auc:.4f} AP={val_ap:.4f}")
print(f"[{epoch:03d}/{args.epochs}] "
f"loss={loss.item():.4f} "
f"val AUC={val_auc:.4f} AP={val_ap:.4f}")
# Save best model
if no_improve >= args.patience:
print(f"Early stopping at epoch {epoch} "
f"(no val-AUC improvement for {args.patience} epochs)")
break
# ---- Restore best checkpoint and compute test metrics ----
model.load_state_dict(best_state)
model_path = os.path.join(args.outdir, "model.pt")
torch.save(model.state_dict(), model_path)
# Final test metrics
model.eval()
with torch.no_grad():
z_final = model.encode(x, data.edge_index)
test_auc, test_ap = eval_linkpred(model, test_data, z_final)
test_auc, test_ap = _eval_linkpred(
z_final, test_data.pos_edge_index, test_data.neg_edge_index
)
# ---- Save outputs ----
model_path = os.path.join(args.outdir, "model.pt")
torch.save(best_state, model_path)
# Save embeddings & metrics
emb_path = os.path.join(args.outdir, "emb.npy")
np.save(emb_path, z_final.cpu().numpy())
@@ -161,19 +165,21 @@ def main():
"val_auc": float(best_val_auc),
"test_auc": float(test_auc),
"test_ap": float(test_ap),
"epochs": args.epochs,
"epochs_ran": epochs_ran,
"epochs_max": args.epochs,
"patience": args.patience,
"hidden": args.hidden,
"latent": args.latent,
"dropout": args.dropout,
"lr": args.lr,
"seed": args.seed
"seed": args.seed,
}
with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
json.dump(metrics, f, indent=2)
print(f"Saved model -> {model_path}")
print(f"Saved embeddings -> {emb_path} (shape={z_final.shape})")
print(f"Metrics: AUC(test)={test_auc:.4f}, AP(test)={test_ap:.4f}")
print(f"\nSaved model {model_path}")
print(f"Saved embeddings {emb_path} shape={z_final.shape}")
print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f}")
if __name__ == "__main__":

View File

@@ -0,0 +1,189 @@
#!/usr/bin/env python3
"""
Visualise VGAE node embeddings using UMAP.
Produces (under --prefix):
{prefix}_{label}_position.png UMAP coloured by genomic position (bin index)
{prefix}_{label}_compartment.png UMAP coloured by A/B compartment (needs --compartments)
{prefix}_joint.png Joint UMAP of all supplied cell lines
{prefix}_stats.csv Per-embedding summary statistics
Usage
-----
python scripts/visualize_embeddings.py \\
--emb results/GM12878/emb.npy results/IMR90/emb.npy \\
--labels GM12878 IMR90 \\
--compartments results/GM12878/compartments_chr21.csv \\
results/IMR90/compartments_chr21.csv \\
--prefix results/figures/umap
"""
import argparse
import os
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import silhouette_score
import umap
COMPARTMENT_COLORS = {"A": "#E41A1C", "B": "#377EB8", "N": "#AAAAAA"}
CELL_LINE_PALETTE = ["#E41A1C", "#4DAF4A", "#984EA3", "#FF7F00", "#377EB8"]
plt.rcParams.update({
"font.family": "sans-serif",
"axes.spines.top": False,
"axes.spines.right": False,
})
def _run_umap(emb: np.ndarray, seed: int = 42) -> np.ndarray:
reducer = umap.UMAP(n_components=2, random_state=seed,
min_dist=0.3, n_neighbors=15)
return reducer.fit_transform(emb)
def _plot_position(coords: np.ndarray, label: str, out_path: str):
fig, ax = plt.subplots(figsize=(6.5, 5.5))
sc = ax.scatter(coords[:, 0], coords[:, 1],
c=np.arange(len(coords)), cmap="plasma",
s=4, alpha=0.75, linewidths=0, rasterized=True)
cbar = plt.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
cbar.set_label("Bin index (5 → 3)", fontsize=9)
ax.set_title(f"{label} — UMAP coloured by genomic position", fontsize=10)
ax.set_xlabel("UMAP 1", fontsize=9)
ax.set_ylabel("UMAP 2", fontsize=9)
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
plt.tight_layout()
plt.savefig(out_path, dpi=300)
plt.close()
def _plot_compartment(coords: np.ndarray, compartments: np.ndarray,
label: str, out_path: str):
fig, ax = plt.subplots(figsize=(6.5, 5.5))
for comp in ("A", "B", "N"):
mask = compartments == comp
if mask.sum() == 0:
continue
ax.scatter(coords[mask, 0], coords[mask, 1],
c=COMPARTMENT_COLORS[comp], s=4, alpha=0.75,
label=f"{comp} ({mask.sum()} bins)", linewidths=0,
rasterized=True)
ax.legend(markerscale=3, title="Compartment", fontsize=9,
title_fontsize=9, frameon=False)
ax.set_title(f"{label} — UMAP coloured by A/B compartment", fontsize=10)
ax.set_xlabel("UMAP 1", fontsize=9)
ax.set_ylabel("UMAP 2", fontsize=9)
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
plt.tight_layout()
plt.savefig(out_path, dpi=300)
plt.close()
def _plot_joint(all_coords: np.ndarray, all_labels: list, out_path: str):
fig, ax = plt.subplots(figsize=(7, 6))
unique = list(dict.fromkeys(all_labels))
arr = np.array(all_labels)
for i, label in enumerate(unique):
mask = arr == label
ax.scatter(all_coords[mask, 0], all_coords[mask, 1],
c=CELL_LINE_PALETTE[i % len(CELL_LINE_PALETTE)],
s=3, alpha=0.6, label=label, linewidths=0, rasterized=True)
ax.legend(markerscale=4, title="Cell line", fontsize=9,
title_fontsize=9, frameon=False)
ax.set_title("Joint UMAP — chromatin topology embeddings", fontsize=11)
ax.set_xlabel("UMAP 1", fontsize=9)
ax.set_ylabel("UMAP 2", fontsize=9)
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
plt.tight_layout()
plt.savefig(out_path, dpi=300)
plt.close()
def _silhouette(emb: np.ndarray, compartments: np.ndarray) -> float:
mask = compartments != "N"
if mask.sum() < 20 or len(set(compartments[mask])) < 2:
return float("nan")
try:
return float(silhouette_score(emb[mask], compartments[mask], metric="cosine"))
except Exception:
return float("nan")
def main():
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
p.add_argument("--emb", nargs="+", required=True,
help="One or more .npy embedding files")
p.add_argument("--labels", nargs="+", required=True,
help="Label for each embedding (same order)")
p.add_argument("--compartments", nargs="+",
help="Compartment CSV files, one per embedding (optional)")
p.add_argument("--prefix", default="results/figures/umap",
help="Output file prefix")
p.add_argument("--seed", type=int, default=42)
args = p.parse_args()
if len(args.emb) != len(args.labels):
raise ValueError("--emb and --labels must have the same length")
os.makedirs(os.path.dirname(os.path.abspath(args.prefix + "_x")), exist_ok=True)
embs = [np.load(f) for f in args.emb]
comp_dfs = []
if args.compartments:
for f in args.compartments:
comp_dfs.append(pd.read_csv(f) if (f and os.path.exists(f)) else None)
stats_rows = []
for i, (emb, label) in enumerate(zip(embs, args.labels)):
print(f"\n[{label}] {emb.shape[0]} nodes × {emb.shape[1]} dims")
coords = _run_umap(emb, seed=args.seed)
tag = label.replace(" ", "_")
_plot_position(coords, label, f"{args.prefix}_{tag}_position.png")
print(f"{args.prefix}_{tag}_position.png")
comp_arr = None
sil = float("nan")
if comp_dfs and i < len(comp_dfs) and comp_dfs[i] is not None:
comp_arr = comp_dfs[i]["compartment"].values[: len(emb)]
_plot_compartment(coords, comp_arr, label,
f"{args.prefix}_{tag}_compartment.png")
print(f"{args.prefix}_{tag}_compartment.png")
sil = _silhouette(emb, comp_arr)
print(f" Silhouette (A/B, cosine): {sil:.4f}")
stats_rows.append({
"label": label,
"n_bins": emb.shape[0],
"latent_dim": emb.shape[1],
"mean_embedding_norm": float(np.linalg.norm(emb, axis=1).mean()),
"std_embedding_values": float(emb.std()),
"silhouette_AB_cosine": sil,
})
# Joint UMAP when multiple embeddings are supplied
if len(embs) > 1:
print("\nComputing joint UMAP…")
all_emb = np.vstack(embs)
all_labels = sum([[lab] * len(e) for lab, e in zip(args.labels, embs)], [])
all_coords = _run_umap(all_emb, seed=args.seed)
_plot_joint(all_coords, all_labels, f"{args.prefix}_joint.png")
print(f"{args.prefix}_joint.png")
stats_df = pd.DataFrame(stats_rows)
stats_path = f"{args.prefix}_stats.csv"
stats_df.to_csv(stats_path, index=False)
print(f"\nStats → {stats_path}")
print(stats_df.to_string(index=False))
if __name__ == "__main__":
main()