v1.0.0: VGAE applied to GM12878 vs IMR90 chr21 Hi-C at 25kb
Full reproducible pipeline: .mcool + ChIP-seq bigwigs → latent embeddings → A/B compartment calls → cross-cell comparison. Key results (chr21, 25 kb, latent dim=32): - Test AUC=0.777, AP=0.759 (converged epoch 31/300) - GM12878 A/B silhouette (cosine) = 0.775 - IMR90 zero-shot silhouette = 0.443 - A-compartment bins stable across cell types (mean cosine Δ=0.042) - B-compartment bins shift substantially (mean cosine Δ=0.451) - 101 B→A and 70 A→B compartment switches GM12878→IMR90
This commit is contained in:
48
.gitignore
vendored
48
.gitignore
vendored
@@ -1,30 +1,30 @@
|
|||||||
|
# Raw sequencing and contact data (large files; download via run_pipeline.sh)
|
||||||
|
data/raw/
|
||||||
|
|
||||||
# Python
|
# Python
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.pyc
|
*.py[cod]
|
||||||
|
*.pyo
|
||||||
|
.pytest_cache/
|
||||||
|
*.egg-info/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
|
||||||
# Conda / mamba envs
|
# Environments
|
||||||
.env/
|
.env
|
||||||
*.yml.lock
|
*.env
|
||||||
|
.venv/
|
||||||
|
|
||||||
# Data
|
# Editors
|
||||||
*.hic
|
.vscode/
|
||||||
*.mcool
|
.idea/
|
||||||
*.cool
|
*.swp
|
||||||
*.bam
|
*~
|
||||||
*.bw
|
|
||||||
*.bigwig
|
|
||||||
*.bed
|
|
||||||
*.pairs*
|
|
||||||
*.pt
|
|
||||||
*.npy
|
|
||||||
*.csv
|
|
||||||
*.png
|
|
||||||
|
|
||||||
# Jupyter and logs
|
# Jupyter
|
||||||
*.ipynb_checkpoints/
|
.ipynb_checkpoints/
|
||||||
*.log
|
*.ipynb
|
||||||
|
|
||||||
|
# OS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
# Results / temp
|
|
||||||
results/
|
|
||||||
data/
|
|
||||||
|
|||||||
60
CITATION.cff
Normal file
60
CITATION.cff
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
cff-version: 1.2.0
|
||||||
|
message: >-
|
||||||
|
If you use this software in your research, please cite it using the
|
||||||
|
following metadata.
|
||||||
|
type: software
|
||||||
|
title: >-
|
||||||
|
chromatin-gnn: Variational Graph Autoencoder for learning latent
|
||||||
|
representations of chromatin topology from Hi-C data
|
||||||
|
authors:
|
||||||
|
- family-names: Okada
|
||||||
|
given-names: Toru
|
||||||
|
alias: ToruOkadaOi
|
||||||
|
# orcid: "https://orcid.org/XXXX-XXXX-XXXX-XXXX" # add your ORCID
|
||||||
|
version: "1.0.0"
|
||||||
|
date-released: "2024-01-01" # update to actual release date
|
||||||
|
doi: "10.5281/zenodo.XXXXXXX" # replace with actual Zenodo DOI after deposit
|
||||||
|
repository-code: "https://github.com/ToruOkadaOi/chromatin-gnn"
|
||||||
|
url: "https://github.com/ToruOkadaOi/chromatin-gnn"
|
||||||
|
license: MIT
|
||||||
|
abstract: >-
|
||||||
|
A Variational Graph Autoencoder (VGAE) applied to Hi-C chromatin contact
|
||||||
|
data to learn unsupervised latent representations of chromatin topology.
|
||||||
|
Genomic bins are modelled as graph nodes with ChIP-seq features (CTCF,
|
||||||
|
H3K27me3); normalised contact frequencies define weighted edges.
|
||||||
|
The model is trained on GM12878 lymphoblastoid cells and evaluated on
|
||||||
|
both link-prediction (AUROC/AP) and the biological interpretability of
|
||||||
|
the latent space against known A/B compartments.
|
||||||
|
keywords:
|
||||||
|
- chromatin
|
||||||
|
- Hi-C
|
||||||
|
- graph neural network
|
||||||
|
- variational autoencoder
|
||||||
|
- VGAE
|
||||||
|
- A/B compartments
|
||||||
|
- topologically associating domains
|
||||||
|
- TAD
|
||||||
|
- epigenomics
|
||||||
|
- 3D genome organisation
|
||||||
|
references:
|
||||||
|
- type: article
|
||||||
|
title: >-
|
||||||
|
A 3D Map of the Human Genome at Kilobase Resolution Reveals
|
||||||
|
Principles of Chromatin Looping
|
||||||
|
authors:
|
||||||
|
- family-names: Rao
|
||||||
|
given-names: "Suhas S. P."
|
||||||
|
- family-names: Huntley
|
||||||
|
given-names: "Miriam H."
|
||||||
|
year: 2014
|
||||||
|
journal: Cell
|
||||||
|
doi: 10.1016/j.cell.2014.11.021
|
||||||
|
- type: article
|
||||||
|
title: "Variational Graph Auto-Encoders"
|
||||||
|
authors:
|
||||||
|
- family-names: Kipf
|
||||||
|
given-names: "Thomas N."
|
||||||
|
- family-names: Welling
|
||||||
|
given-names: Max
|
||||||
|
year: 2016
|
||||||
|
url: "https://arxiv.org/abs/1611.07308"
|
||||||
309
README.md
309
README.md
@@ -1,19 +1,302 @@
|
|||||||
# Graph learning for genome architecture
|
# chromatin-gnn
|
||||||
|
|
||||||
### workflow
|
**Variational Graph Autoencoder for learning latent representations of chromatin topology from Hi-C data**
|
||||||
|
|
||||||
|
[](https://doi.org/10.5281/zenodo.XXXXXXX)
|
||||||
|
[](LICENSE)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The three-dimensional organisation of chromatin in the nucleus is not random. Chromosomes fold into compartments, topologically associating domains (TADs), and loop structures that correlate strongly with gene regulation. Hi-C sequencing captures these contacts genome-wide, but the resulting data are high-dimensional and require principled dimensionality reduction to extract biologically interpretable structure.
|
||||||
|
|
||||||
|
This repository applies a **Variational Graph Autoencoder (VGAE)** to Hi-C contact data to learn a compact, continuous latent representation of chromatin topology. Genomic bins are treated as graph nodes; normalised contact frequencies define weighted edges; ChIP-seq tracks for CTCF and H3K27me3 supply node features. The model is trained end-to-end on a link-prediction objective and evaluated for its ability to recover known biological structure — A/B compartments — in an entirely unsupervised manner.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Scientific question
|
||||||
|
|
||||||
|
> Can a VGAE learn biologically meaningful latent representations of chromatin topology — capturing A/B compartments and cell-type-specific reorganisation — from Hi-C contact data alone, in an unsupervised manner?
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
Node features (2D: CTCF, H3K27me3)
|
||||||
|
│
|
||||||
|
BatchNorm
|
||||||
|
│
|
||||||
|
GCNConv(64) ← shared message-passing layer
|
||||||
|
│
|
||||||
|
ReLU + Dropout(0.2)
|
||||||
|
/ \
|
||||||
|
GCNConv(32) GCNConv(32)
|
||||||
|
μ log σ
|
||||||
|
\ /
|
||||||
|
Reparameterisation
|
||||||
|
│
|
||||||
|
z ∈ ℝ³² (node embeddings)
|
||||||
|
│
|
||||||
|
Inner-product decoder
|
||||||
|
(link prediction objective: binary cross-entropy + KL divergence)
|
||||||
|
```
|
||||||
|
|
||||||
|
The encoder is a two-layer Graph Convolutional Network (Kipf & Welling 2016, 2017) with a BatchNorm input layer. The decoder is the standard dot-product decoder used in the original VGAE paper. Training uses a link-prediction objective: the model is asked to distinguish real Hi-C contacts from randomly sampled non-contacts.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Dataset
|
||||||
|
|
||||||
|
All data are from the GRCh38/hg38 reference genome, chromosome 21 at 25 kb resolution.
|
||||||
|
|
||||||
|
| File | Cell line | Type | Source | Accession |
|
||||||
|
|------|-----------|------|--------|-----------|
|
||||||
|
| GM12878.mcool | GM12878 (lymphoblastoid) | Hi-C contact matrix | 4DN Data Portal | 4DNFIRUMEC32 |
|
||||||
|
| IMR90.mcool | IMR-90 (lung fibroblast) | Hi-C contact matrix | 4DN Data Portal | 4DNFIABB3FHQ |
|
||||||
|
| GM12878_CTCF.bw | GM12878 | CTCF ChIP-seq (FC/control) | ENCODE | ENCFF741BAQ (exp. ENCSR000AKB) |
|
||||||
|
| GM12878_H3K27me3.bw | GM12878 | H3K27me3 ChIP-seq (FC/control) | ENCODE | ENCFF736CNQ (exp. ENCSR000AKD) |
|
||||||
|
| IMR90_CTCF.bw | IMR-90 | CTCF ChIP-seq (FC/control) | ENCODE | ENCFF770DUD (exp. ENCSR000EFI) |
|
||||||
|
| IMR90_H3K27me3.bw | IMR-90 | H3K27me3 ChIP-seq (FC/control) | ENCODE | ENCFF158HZL (exp. ENCSR431UUY) |
|
||||||
|
|
||||||
|
**Graph statistics:**
|
||||||
|
|
||||||
|
| Cell line | Bins (chr21, 25 kb) | Edges (contacts) | Node features |
|
||||||
|
|-----------|---------------------|------------------|---------------|
|
||||||
|
| GM12878 | 1,869 | 87,557 | 2 (CTCF, H3K27me3) |
|
||||||
|
| IMR90 | 1,869 | 136,121 | 2 (CTCF, H3K27me3) |
|
||||||
|
|
||||||
|
IMR90 has ~55% more intra-chromosomal contacts than GM12878 at chr21, suggesting a more compact or contact-rich chromatin organisation in this fibroblast cell line.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Build graph (contact matrix & bigwig needed)
|
conda create -n chromatin_gnn python=3.10 -y
|
||||||
python scripts/build_graph.py --mcool x.mcool --chrom chrx --res x
|
conda activate chromatin_gnn
|
||||||
--bigwigs xCTCFx.bw xH3K27me3x.bw --out data/chrx_xconditionx.pt
|
|
||||||
|
|
||||||
# Train model
|
# CPU-only PyTorch (replace URL for GPU builds)
|
||||||
python scripts/train_vgae.py --graph data/chrx_xconditionx.pt --epochs 100 --outdir results
|
pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
|
||||||
# Encode treatment graph
|
# All other dependencies
|
||||||
python scripts/encode_graph.py --model results/model.pt --graph data/chrx_xtreatmentx.pt --out results/emb_xtreatmentx.npy
|
pip install torch-geometric==2.5.3 cooler==0.9.3 pyBigWig pandas \
|
||||||
|
"numpy>=1.24,<2.0" scikit-learn matplotlib umap-learn scipy seaborn tqdm
|
||||||
|
```
|
||||||
|
|
||||||
# Compare embeddings
|
> **Note:** `cooler==0.9.3` requires `numpy<2.0`. The env.yml captures the exact versions used for this release.
|
||||||
python scripts/compare_embeddings_general.py --emb1 results/emb.npy --emb2 results/emb_xtreatmentx.npy \
|
|
||||||
--label1 xControlx --label2 xTreatmentx --prefix results/chrx
|
---
|
||||||
```
|
|
||||||
|
## Workflow
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Full end-to-end run (downloads bigwigs automatically; .mcool files must be present in data/raw/)
|
||||||
|
bash run_pipeline.sh
|
||||||
|
|
||||||
|
# Or run individual steps:
|
||||||
|
|
||||||
|
# 1. Build contact graph
|
||||||
|
python scripts/build_graph.py \
|
||||||
|
--mcool data/raw/GM12878.mcool \
|
||||||
|
--chrom chr21 --res 25000 \
|
||||||
|
--bigwigs data/raw/GM12878_CTCF.bw data/raw/GM12878_H3K27me3.bw \
|
||||||
|
--out data/processed/GM12878_chr21.pt
|
||||||
|
|
||||||
|
# 2. Compute A/B compartments
|
||||||
|
python scripts/compute_compartments.py \
|
||||||
|
--mcool data/raw/GM12878.mcool --chrom chr21 --res 25000 \
|
||||||
|
--bigwig_orient data/raw/GM12878_CTCF.bw \
|
||||||
|
--out results/GM12878/compartments_chr21.csv
|
||||||
|
|
||||||
|
# 3. Train VGAE
|
||||||
|
python scripts/train_vgae.py \
|
||||||
|
--graph data/processed/GM12878_chr21.pt \
|
||||||
|
--epochs 300 --patience 20 --hidden 64 --latent 32 \
|
||||||
|
--outdir results/GM12878
|
||||||
|
|
||||||
|
# 4. Encode a second cell line with the trained model
|
||||||
|
python scripts/encode_graph.py \
|
||||||
|
--model results/GM12878/model.pt \
|
||||||
|
--graph data/processed/IMR90_chr21.pt \
|
||||||
|
--out results/IMR90/emb.npy
|
||||||
|
|
||||||
|
# 5. UMAP visualisation
|
||||||
|
python scripts/visualize_embeddings.py \
|
||||||
|
--emb results/GM12878/emb.npy results/IMR90/emb.npy \
|
||||||
|
--labels GM12878 IMR90 \
|
||||||
|
--compartments results/GM12878/compartments_chr21.csv \
|
||||||
|
results/IMR90/compartments_chr21.csv \
|
||||||
|
--prefix results/figures/umap
|
||||||
|
|
||||||
|
# 6. Per-bin embedding comparison
|
||||||
|
python scripts/compare_embeddings.py \
|
||||||
|
--emb1 results/GM12878/emb.npy --emb2 results/IMR90/emb.npy \
|
||||||
|
--label1 GM12878 --label2 IMR90 \
|
||||||
|
--prefix results/figures/chr21
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
### Training (GM12878, chr21, 25 kb)
|
||||||
|
|
||||||
|
| Metric | Value |
|
||||||
|
|--------|-------|
|
||||||
|
| Epochs to convergence | 31 / 300 (early stopping, patience=20) |
|
||||||
|
| Validation AUC (link prediction) | 0.774 |
|
||||||
|
| Test AUC | 0.777 |
|
||||||
|
| Test AP | 0.759 |
|
||||||
|
| Latent dimensionality | 32 |
|
||||||
|
|
||||||
|
The model converged rapidly, suggesting that the graph structure of chr21 at 25 kb is learnable with a shallow two-layer GCN.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### A/B compartment separation in the latent space
|
||||||
|
|
||||||
|
The UMAP of GM12878 node embeddings coloured by A/B compartment shows strong, clean separation of the two compartment types without the model ever receiving compartment labels during training.
|
||||||
|
|
||||||
|
| Cell line | Silhouette score (A/B, cosine) | A bins | B bins | Masked (N) |
|
||||||
|
|-----------|-------------------------------|--------|--------|------------|
|
||||||
|
| GM12878 (training) | **0.775** | 602 | 683 | 584 |
|
||||||
|
| IMR90 (zero-shot) | 0.443 | 614 | 709 | 546 |
|
||||||
|
|
||||||
|
The GM12878 silhouette of **0.775** indicates that the VGAE has learned a latent space in which A and B compartments are nearly linearly separable — a strong signal given that compartment identity was never provided as a training label.
|
||||||
|
|
||||||
|
For IMR90, encoded zero-shot with the GM12878-trained model, the silhouette drops to **0.443**. This is expected: the model's BatchNorm statistics were fit to GM12878, and IMR90's chromatin organisation partially diverges.
|
||||||
|
|
||||||
|
**Figures:**
|
||||||
|
|
||||||
|
| Figure | Description |
|
||||||
|
|--------|-------------|
|
||||||
|
| `results/figures/umap_GM12878_compartment.png` | GM12878 UMAP coloured by A/B compartment |
|
||||||
|
| `results/figures/umap_GM12878_position.png` | GM12878 UMAP coloured by genomic position |
|
||||||
|
| `results/figures/umap_IMR90_compartment.png` | IMR90 UMAP coloured by A/B compartment |
|
||||||
|
| `results/figures/umap_joint.png` | Joint UMAP of both cell lines |
|
||||||
|
| `results/figures/chr21_delta.png` | Per-bin cosine distance track (GM12878 vs IMR90) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Cell-type comparison: GM12878 vs IMR90
|
||||||
|
|
||||||
|
The IMR90 graph was encoded with the GM12878-trained model (zero-shot transfer). Per-bin cosine distances between the two embedding matrices reveal which genomic loci undergo the largest chromatin reorganisation between cell types.
|
||||||
|
|
||||||
|
**Per-bin cosine distance summary:**
|
||||||
|
|
||||||
|
| Statistic | Value |
|
||||||
|
|-----------|-------|
|
||||||
|
| Mean | 0.245 |
|
||||||
|
| Median | 0.028 |
|
||||||
|
| Bins with distance < 0.1 (stable) | 968 / 1,869 (52%) |
|
||||||
|
| Bins with distance > 0.5 (high shift) | 293 / 1,869 (16%) |
|
||||||
|
|
||||||
|
**Mean cosine distance by GM12878 compartment:**
|
||||||
|
|
||||||
|
| Compartment | Mean distance | Median distance |
|
||||||
|
|-------------|--------------|-----------------|
|
||||||
|
| A (active) | 0.042 | 0.001 |
|
||||||
|
| B (repressive) | 0.451 | 0.352 |
|
||||||
|
| N (masked) | 0.213 | 0.000 |
|
||||||
|
|
||||||
|
**Key finding:** A-compartment bins are nearly invariant between the two cell types (mean Δ = 0.042), while B-compartment bins shift substantially (mean Δ = 0.451). This is consistent with the known biology: constitutively active chromatin domains tend to be conserved across cell types, while heterochromatic B-compartment organisation is more cell-type-specific.
|
||||||
|
|
||||||
|
**Compartment switches (GM12878 → IMR90):**
|
||||||
|
|
||||||
|
| GM12878 | IMR90 | Bins | Interpretation |
|
||||||
|
|---------|-------|------|----------------|
|
||||||
|
| A | A | 493 | Stable active |
|
||||||
|
| B | B | 581 | Stable repressive |
|
||||||
|
| B → A | A | 101 | Loci that open in IMR90 |
|
||||||
|
| A → B | B | 70 | Loci that close in IMR90 |
|
||||||
|
|
||||||
|
101 loci switch from B (repressive in GM12878) to A (active in IMR90), versus 70 in the reverse direction. This asymmetry suggests that IMR90 fibroblasts activate more lineage-specific loci on chr21 than GM12878 lymphoblastoid cells.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Biological interpretation
|
||||||
|
|
||||||
|
The results demonstrate that a VGAE trained on Hi-C data **recovers A/B compartment structure without supervision** (silhouette = 0.775). The latent space organises chr21 bins according to their chromatin state, not just their linear genomic position — the UMAP coloured by genomic position shows a broadly continuous gradient, while the compartment-coloured UMAP shows discrete clusters.
|
||||||
|
|
||||||
|
The zero-shot application to IMR90 captures the partial conservation of compartment organisation across cell types. The B-compartment instability revealed by the cosine distance analysis is consistent with the literature: heterochromatin rewiring is a known driver of cell-type identity (Lieberman-Aiden et al. 2009; Dixon et al. 2015).
|
||||||
|
|
||||||
|
Notably, the model achieves this with only two node features (CTCF and H3K27me3 signal), demonstrating that a modest set of epigenomic marks, combined with contact topology, is sufficient to encode the major axis of chromatin organisation.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
1. **Single chromosome, single resolution.** Results are for chr21 at 25 kb only. Chr21 is acrocentric with a large masked pericentromeric region (584 / 1,869 bins masked in GM12878), which may reduce statistical power compared to gene-rich autosomes.
|
||||||
|
|
||||||
|
2. **Shallow encoder.** The two-layer GCN has a local receptive field (2-hop neighbourhood). Long-range chromatin interactions spanning multiple TADs are not directly encoded. Deeper networks or attention-based architectures may capture these better.
|
||||||
|
|
||||||
|
3. **Link-prediction objective ≠ compartment recovery.** The model is optimised to predict contacts, not compartments. The strong silhouette score is emergent, not guaranteed. The objective could be supplemented with biologically-informed losses.
|
||||||
|
|
||||||
|
4. **Zero-shot transfer with fixed BatchNorm.** Encoding IMR90 with GM12878 BatchNorm statistics means the model sees IMR90 features in GM12878's normalisation frame. A domain-adaptation approach (e.g., re-fitting BatchNorm on IMR90 with frozen GCN weights) would give a fairer comparison.
|
||||||
|
|
||||||
|
5. **Compartment calling is approximate.** The O/E → Pearson correlation → PC1 pipeline is sensitive to the choice of orientation signal (CTCF here). Bins with low coverage are masked and assigned no compartment label, which affects the silhouette calculation.
|
||||||
|
|
||||||
|
6. **No TAD-level evaluation.** TAD boundary detection would require a graph-level or boundary-aware metric. The current evaluation is node-level only.
|
||||||
|
|
||||||
|
7. **No statistical significance testing.** The compartment switch counts (101 B→A, 70 A→B) have not been tested for significance against a null model (e.g., random permutation of compartment labels).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Future work
|
||||||
|
|
||||||
|
- Apply to all autosomes and compare genome-wide compartment recovery.
|
||||||
|
- Add a TAD-boundary evaluation metric (e.g., insulation score correlation with latent space gradients).
|
||||||
|
- Fine-tune on IMR90 (transfer learning) to improve the IMR90 silhouette score.
|
||||||
|
- Add cohesin depletion or auxin-inducible degron (AID) perturbation data as a controlled condition comparison.
|
||||||
|
- Replace the inner-product decoder with a distance-aware decoder that incorporates linear genomic distance.
|
||||||
|
- Benchmark against PCA/UMAP of the raw contact matrix and against other graph-based methods (GraphSAGE, GAT).
|
||||||
|
- Extend node features to include additional histone marks (H3K4me3, H3K27ac, H3K9me3) to test whether richer epigenomic context improves compartment recovery.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Scripts
|
||||||
|
|
||||||
|
| Script | Purpose |
|
||||||
|
|--------|---------|
|
||||||
|
| `scripts/build_graph.py` | Convert .mcool + bigWigs → PyG Data object |
|
||||||
|
| `scripts/train_vgae.py` | Train VGAE with link-prediction objective + early stopping |
|
||||||
|
| `scripts/encode_graph.py` | Encode a new graph with a trained model |
|
||||||
|
| `scripts/compute_compartments.py` | O/E matrix → Pearson PCA → A/B compartment calls |
|
||||||
|
| `scripts/visualize_embeddings.py` | UMAP + compartment visualisation + silhouette |
|
||||||
|
| `scripts/compare_embeddings.py` | Per-bin cosine/L2/L1 distance between two embeddings |
|
||||||
|
| `scripts/model.py` | Shared Encoder class (imported by train and encode scripts) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
If you use this code or data in your research, please cite:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@software{okada_chromatin_gnn_2024,
|
||||||
|
author = {Okada, Toru},
|
||||||
|
title = {{chromatin-gnn: Variational Graph Autoencoder for learning
|
||||||
|
latent representations of chromatin topology from Hi-C data}},
|
||||||
|
year = {2024},
|
||||||
|
doi = {10.5281/zenodo.XXXXXXX},
|
||||||
|
url = {https://github.com/ToruOkadaOi/chromatin-gnn},
|
||||||
|
version = {1.0.0}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
See also `CITATION.cff` for CFF-format metadata.
|
||||||
|
|
||||||
|
**Key references:**
|
||||||
|
|
||||||
|
- Kipf, T. N. & Welling, M. (2016). Variational Graph Auto-Encoders. *arXiv:1611.07308*
|
||||||
|
- Kipf, T. N. & Welling, M. (2017). Semi-Supervised Classification with Graph Convolutional Networks. *ICLR 2017*
|
||||||
|
- Rao, S. S. P. et al. (2014). A 3D Map of the Human Genome at Kilobase Resolution Reveals Principles of Chromatin Looping. *Cell*, 159(7), 1665–1680. https://doi.org/10.1016/j.cell.2014.11.021
|
||||||
|
- Lieberman-Aiden, E. et al. (2009). Comprehensive mapping of long-range interactions reveals folding principles of the human genome. *Science*, 326(5950), 289–293.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT — see [LICENSE](LICENSE).
|
||||||
|
|||||||
BIN
data/processed/GM12878_chr21.pt
Normal file
BIN
data/processed/GM12878_chr21.pt
Normal file
Binary file not shown.
BIN
data/processed/IMR90_chr21.pt
Normal file
BIN
data/processed/IMR90_chr21.pt
Normal file
Binary file not shown.
36
env.yml
36
env.yml
@@ -1,21 +1,27 @@
|
|||||||
name: chromatin_gnn_aman
|
name: chromatin_gnn
|
||||||
|
# Installation (pip-based, conda used only for Python):
|
||||||
|
# conda create -n chromatin_gnn python=3.10 -y
|
||||||
|
# conda activate chromatin_gnn
|
||||||
|
# pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
# pip install -r requirements.txt
|
||||||
|
#
|
||||||
|
# For GPU support replace the torch line with:
|
||||||
|
# pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cu118
|
||||||
channels:
|
channels:
|
||||||
- pytorch
|
|
||||||
- conda-forge
|
|
||||||
- defaults
|
- defaults
|
||||||
dependencies:
|
dependencies:
|
||||||
- python=3.10
|
- python=3.10
|
||||||
- pytorch
|
|
||||||
- torchvision
|
|
||||||
- torchaudio
|
|
||||||
- cooler
|
|
||||||
- pybigwig
|
|
||||||
- pandas
|
|
||||||
- numpy
|
|
||||||
- scikit-learn
|
|
||||||
- matplotlib
|
|
||||||
- umap-learn
|
|
||||||
- pip
|
- pip
|
||||||
- pip:
|
- pip:
|
||||||
- torch-geometric
|
- torch==2.1.2 # CPU build; see GPU note above
|
||||||
|
- torch-geometric==2.5.3
|
||||||
|
- cooler==0.9.3
|
||||||
|
- pyBigWig==0.3.25
|
||||||
|
- pandas==2.1.4
|
||||||
|
- "numpy>=1.24,<2.0" # cooler 0.9.3 requires numpy<2.0
|
||||||
|
- scikit-learn==1.4.2
|
||||||
|
- matplotlib==3.8.4
|
||||||
|
- umap-learn==0.5.12
|
||||||
|
- scipy==1.12.0
|
||||||
|
- seaborn==0.13.2
|
||||||
|
- tqdm==4.66.2
|
||||||
|
|||||||
1870
results/GM12878/compartments_chr21.csv
Normal file
1870
results/GM12878/compartments_chr21.csv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
results/GM12878/emb.npy
Normal file
BIN
results/GM12878/emb.npy
Normal file
Binary file not shown.
13
results/GM12878/metrics.json
Normal file
13
results/GM12878/metrics.json
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
{
|
||||||
|
"val_auc": 0.774052272942853,
|
||||||
|
"test_auc": 0.7767560244060842,
|
||||||
|
"test_ap": 0.7585872967832136,
|
||||||
|
"epochs_ran": 31,
|
||||||
|
"epochs_max": 300,
|
||||||
|
"patience": 20,
|
||||||
|
"hidden": 64,
|
||||||
|
"latent": 32,
|
||||||
|
"dropout": 0.2,
|
||||||
|
"lr": 0.001,
|
||||||
|
"seed": 42
|
||||||
|
}
|
||||||
BIN
results/GM12878/model.pt
Normal file
BIN
results/GM12878/model.pt
Normal file
Binary file not shown.
1870
results/IMR90/compartments_chr21.csv
Normal file
1870
results/IMR90/compartments_chr21.csv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
results/IMR90/emb.npy
Normal file
BIN
results/IMR90/emb.npy
Normal file
Binary file not shown.
1870
results/figures/chr21_delta.csv
Normal file
1870
results/figures/chr21_delta.csv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
results/figures/chr21_delta.png
Normal file
BIN
results/figures/chr21_delta.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 219 KiB |
BIN
results/figures/umap_GM12878_compartment.png
Normal file
BIN
results/figures/umap_GM12878_compartment.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 189 KiB |
BIN
results/figures/umap_GM12878_position.png
Normal file
BIN
results/figures/umap_GM12878_position.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 238 KiB |
BIN
results/figures/umap_IMR90_compartment.png
Normal file
BIN
results/figures/umap_IMR90_compartment.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 204 KiB |
BIN
results/figures/umap_IMR90_position.png
Normal file
BIN
results/figures/umap_IMR90_position.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 267 KiB |
BIN
results/figures/umap_joint.png
Normal file
BIN
results/figures/umap_joint.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 304 KiB |
3
results/figures/umap_stats.csv
Normal file
3
results/figures/umap_stats.csv
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
label,n_bins,latent_dim,mean_embedding_norm,std_embedding_values,silhouette_AB_cosine
|
||||||
|
GM12878,1869,32,0.6679670810699463,0.1371903121471405,0.7748898267745972
|
||||||
|
IMR90,1869,32,0.7111130356788635,0.14772675931453705,0.4431541860103607
|
||||||
|
162
run_pipeline.sh
Normal file
162
run_pipeline.sh
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
#
|
||||||
|
# run_pipeline.sh – end-to-end pipeline for chromatin-gnn
|
||||||
|
#
|
||||||
|
# Prerequisites
|
||||||
|
# -------------
|
||||||
|
# 1. Create and activate the environment:
|
||||||
|
# conda create -n chromatin_gnn python=3.10 -y
|
||||||
|
# conda activate chromatin_gnn
|
||||||
|
# pip install torch==2.1.2 --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
# pip install torch-geometric==2.5.3 cooler==0.9.3 pyBigWig pandas \
|
||||||
|
# "numpy>=1.24,<2.0" scikit-learn matplotlib umap-learn scipy seaborn tqdm
|
||||||
|
#
|
||||||
|
# 2. Download raw data into data/raw/ (see README.md § Dataset for URLs):
|
||||||
|
# GM12878.mcool 4DN accession 4DNFIRUMEC32
|
||||||
|
# IMR90.mcool 4DN accession 4DNFIABB3FHQ
|
||||||
|
# GM12878_CTCF.bw ENCODE ENCFF741BAQ (experiment ENCSR000AKB)
|
||||||
|
# GM12878_H3K27me3.bw ENCODE ENCFF736CNQ (experiment ENCSR000AKD)
|
||||||
|
# IMR90_CTCF.bw ENCODE ENCFF770DUD (experiment ENCSR000EFI)
|
||||||
|
# IMR90_H3K27me3.bw ENCODE ENCFF158HZL (experiment ENCSR431UUY)
|
||||||
|
#
|
||||||
|
# Usage
|
||||||
|
# -----
|
||||||
|
# bash run_pipeline.sh [--chrom chr21] [--res 25000] [--epochs 300]
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
# ========== Configuration ==========
|
||||||
|
CHROM="${CHROM:-chr21}"
|
||||||
|
RES="${RES:-25000}"
|
||||||
|
EPOCHS="${EPOCHS:-300}"
|
||||||
|
PATIENCE="${PATIENCE:-20}"
|
||||||
|
HIDDEN="${HIDDEN:-64}"
|
||||||
|
LATENT="${LATENT:-32}"
|
||||||
|
SEED="${SEED:-42}"
|
||||||
|
|
||||||
|
REPO="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
SCRIPTS="$REPO/scripts"
|
||||||
|
DATA="$REPO/data"
|
||||||
|
RESULTS="$REPO/results"
|
||||||
|
|
||||||
|
# ========== Directories ==========
|
||||||
|
mkdir -p "$DATA/raw" "$DATA/processed" \
|
||||||
|
"$RESULTS/GM12878" "$RESULTS/IMR90" "$RESULTS/figures"
|
||||||
|
|
||||||
|
# ========== Download ENCODE bigWig tracks ==========
|
||||||
|
echo "=== Downloading ENCODE bigWig tracks ==="
|
||||||
|
for entry in \
|
||||||
|
"GM12878_CTCF.bw|https://www.encodeproject.org/files/ENCFF741BAQ/@@download/ENCFF741BAQ.bigWig" \
|
||||||
|
"GM12878_H3K27me3.bw|https://www.encodeproject.org/files/ENCFF736CNQ/@@download/ENCFF736CNQ.bigWig" \
|
||||||
|
"IMR90_CTCF.bw|https://www.encodeproject.org/files/ENCFF770DUD/@@download/ENCFF770DUD.bigWig" \
|
||||||
|
"IMR90_H3K27me3.bw|https://www.encodeproject.org/files/ENCFF158HZL/@@download/ENCFF158HZL.bigWig"
|
||||||
|
do
|
||||||
|
fname="${entry%%|*}"
|
||||||
|
url="${entry##*|}"
|
||||||
|
out="$DATA/raw/$fname"
|
||||||
|
if [ -f "$out" ]; then
|
||||||
|
echo " $fname already present, skipping"
|
||||||
|
else
|
||||||
|
echo " Downloading $fname ..."
|
||||||
|
wget -q --show-progress -O "$out" "$url"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# .mcool files must be downloaded manually from 4DN (requires free account):
|
||||||
|
# GM12878: https://data.4dnucleome.org/files-processed/4DNFIRUMEC32/@@download/4DNFIRUMEC32.mcool
|
||||||
|
# IMR90: https://data.4dnucleome.org/files-processed/4DNFIABB3FHQ/@@download/4DNFIABB3FHQ.mcool
|
||||||
|
for f in GM12878.mcool IMR90.mcool; do
|
||||||
|
if [ ! -f "$DATA/raw/$f" ]; then
|
||||||
|
echo "ERROR: $DATA/raw/$f not found. Download from 4DN (see README) and retry." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# ========== Step 1: Build contact graphs ==========
|
||||||
|
echo ""
|
||||||
|
echo "=== Step 1: Building chromatin contact graphs ==="
|
||||||
|
for CELL in GM12878 IMR90; do
|
||||||
|
OUT="$DATA/processed/${CELL}_${CHROM}.pt"
|
||||||
|
if [ -f "$OUT" ]; then
|
||||||
|
echo " ${CELL} graph already exists, skipping"
|
||||||
|
else
|
||||||
|
python "$SCRIPTS/build_graph.py" \
|
||||||
|
--mcool "$DATA/raw/${CELL}.mcool" \
|
||||||
|
--chrom "$CHROM" --res "$RES" \
|
||||||
|
--bigwigs "$DATA/raw/${CELL}_CTCF.bw" "$DATA/raw/${CELL}_H3K27me3.bw" \
|
||||||
|
--out "$OUT"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# ========== Step 2: Compute A/B compartments ==========
|
||||||
|
echo ""
|
||||||
|
echo "=== Step 2: Computing A/B compartments (PC1 of O/E Pearson correlation) ==="
|
||||||
|
for CELL in GM12878 IMR90; do
|
||||||
|
OUT="$RESULTS/${CELL}/compartments_${CHROM}.csv"
|
||||||
|
if [ -f "$OUT" ]; then
|
||||||
|
echo " ${CELL} compartments already exist, skipping"
|
||||||
|
else
|
||||||
|
python "$SCRIPTS/compute_compartments.py" \
|
||||||
|
--mcool "$DATA/raw/${CELL}.mcool" \
|
||||||
|
--chrom "$CHROM" --res "$RES" \
|
||||||
|
--bigwig_orient "$DATA/raw/${CELL}_CTCF.bw" \
|
||||||
|
--out "$OUT"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# ========== Step 3: Train VGAE on GM12878 ==========
|
||||||
|
echo ""
|
||||||
|
echo "=== Step 3: Training VGAE on GM12878 ==="
|
||||||
|
if [ -f "$RESULTS/GM12878/model.pt" ]; then
|
||||||
|
echo " Trained model already exists, skipping"
|
||||||
|
else
|
||||||
|
python "$SCRIPTS/train_vgae.py" \
|
||||||
|
--graph "$DATA/processed/GM12878_${CHROM}.pt" \
|
||||||
|
--epochs "$EPOCHS" --patience "$PATIENCE" \
|
||||||
|
--hidden "$HIDDEN" --latent "$LATENT" \
|
||||||
|
--seed "$SEED" \
|
||||||
|
--outdir "$RESULTS/GM12878"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ========== Step 4: Encode IMR90 with GM12878 model ==========
|
||||||
|
echo ""
|
||||||
|
echo "=== Step 4: Encoding IMR90 graph with trained GM12878 model ==="
|
||||||
|
if [ -f "$RESULTS/IMR90/emb.npy" ]; then
|
||||||
|
echo " IMR90 embeddings already exist, skipping"
|
||||||
|
else
|
||||||
|
python "$SCRIPTS/encode_graph.py" \
|
||||||
|
--model "$RESULTS/GM12878/model.pt" \
|
||||||
|
--graph "$DATA/processed/IMR90_${CHROM}.pt" \
|
||||||
|
--out "$RESULTS/IMR90/emb.npy"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ========== Step 5: Visualise embeddings ==========
|
||||||
|
echo ""
|
||||||
|
echo "=== Step 5: Generating UMAP visualisations ==="
|
||||||
|
python "$SCRIPTS/visualize_embeddings.py" \
|
||||||
|
--emb "$RESULTS/GM12878/emb.npy" "$RESULTS/IMR90/emb.npy" \
|
||||||
|
--labels GM12878 IMR90 \
|
||||||
|
--compartments \
|
||||||
|
"$RESULTS/GM12878/compartments_${CHROM}.csv" \
|
||||||
|
"$RESULTS/IMR90/compartments_${CHROM}.csv" \
|
||||||
|
--prefix "$RESULTS/figures/umap" \
|
||||||
|
--seed "$SEED"
|
||||||
|
|
||||||
|
# ========== Step 6: Compare embeddings ==========
|
||||||
|
echo ""
|
||||||
|
echo "=== Step 6: Comparing GM12878 vs IMR90 embeddings ==="
|
||||||
|
python "$SCRIPTS/compare_embeddings.py" \
|
||||||
|
--emb1 "$RESULTS/GM12878/emb.npy" \
|
||||||
|
--emb2 "$RESULTS/IMR90/emb.npy" \
|
||||||
|
--label1 GM12878 --label2 IMR90 \
|
||||||
|
--prefix "$RESULTS/figures/${CHROM}"
|
||||||
|
|
||||||
|
# ========== Summary ==========
|
||||||
|
echo ""
|
||||||
|
echo "=== Pipeline complete ==="
|
||||||
|
echo "Outputs:"
|
||||||
|
echo " Model + embeddings : $RESULTS/GM12878/"
|
||||||
|
echo " Figures : $RESULTS/figures/"
|
||||||
|
echo " Metrics : $RESULTS/GM12878/metrics.json"
|
||||||
|
echo ""
|
||||||
|
cat "$RESULTS/GM12878/metrics.json"
|
||||||
@@ -50,7 +50,8 @@ def main():
|
|||||||
if emb1.shape != emb2.shape:
|
if emb1.shape != emb2.shape:
|
||||||
raise ValueError(f"Shape mismatch: {emb1.shape} vs {emb2.shape}")
|
raise ValueError(f"Shape mismatch: {emb1.shape} vs {emb2.shape}")
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(args.prefix), exist_ok=True)
|
prefix_dir = os.path.dirname(os.path.abspath(args.prefix))
|
||||||
|
os.makedirs(prefix_dir, exist_ok=True)
|
||||||
n_bins, n_dim = emb1.shape
|
n_bins, n_dim = emb1.shape
|
||||||
print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims")
|
print(f"Loaded embeddings: {n_bins} bins × {n_dim} dims")
|
||||||
|
|
||||||
|
|||||||
170
scripts/compute_compartments.py
Normal file
170
scripts/compute_compartments.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Compute A/B chromatin compartments from a Hi-C .mcool file.
|
||||||
|
|
||||||
|
Algorithm
|
||||||
|
---------
|
||||||
|
1. Load ICE-balanced contact matrix for the target chromosome.
|
||||||
|
2. Distance-normalise to O/E (divide each diagonal by its mean contact frequency).
|
||||||
|
3. Compute Pearson correlation matrix of the O/E rows.
|
||||||
|
4. PCA of the correlation matrix; PC1 distinguishes A from B compartments.
|
||||||
|
5. Orient the PC1 sign using --bigwig_orient (e.g. CTCF):
|
||||||
|
positive PC1 → high signal in that track.
|
||||||
|
With CTCF: positive PC1 = CTCF-enriched = A compartment (active).
|
||||||
|
With H3K27me3: pass --flip_orient so positive PC1 = B compartment (repressive).
|
||||||
|
|
||||||
|
Output
|
||||||
|
------
|
||||||
|
CSV with columns: chrom, start, end, pc1, compartment (A / B / N for masked bins).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import cooler
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
def _bin_bigwig(bw_path: str, chrom: str, bins) -> np.ndarray:
|
||||||
|
"""Average bigWig signal over a list of (start, end) genomic bins."""
|
||||||
|
import pyBigWig
|
||||||
|
bw = pyBigWig.open(bw_path)
|
||||||
|
chrom_len = bw.chroms().get(chrom, 0)
|
||||||
|
vals = []
|
||||||
|
for s, e in bins:
|
||||||
|
s, e = max(0, int(s)), min(chrom_len, int(e))
|
||||||
|
if s >= e:
|
||||||
|
vals.append(0.0)
|
||||||
|
continue
|
||||||
|
v = bw.stats(chrom, s, e, type="mean")[0]
|
||||||
|
vals.append(0.0 if v is None or np.isnan(v) else float(v))
|
||||||
|
bw.close()
|
||||||
|
return np.array(vals)
|
||||||
|
|
||||||
|
|
||||||
|
def _observed_over_expected(matrix: np.ndarray) -> np.ndarray:
|
||||||
|
"""Distance-normalise a symmetric contact matrix (O/E transform)."""
|
||||||
|
n = matrix.shape[0]
|
||||||
|
oe = np.zeros((n, n), dtype=float)
|
||||||
|
for d in range(n):
|
||||||
|
idx = np.arange(n - d)
|
||||||
|
diag = matrix[idx, idx + d].astype(float)
|
||||||
|
positive = diag[diag > 0]
|
||||||
|
if positive.size == 0:
|
||||||
|
continue
|
||||||
|
mean_d = positive.mean()
|
||||||
|
norm_diag = np.where((np.isnan(diag)) | (diag == 0), 0.0, diag / mean_d)
|
||||||
|
oe[idx, idx + d] = norm_diag
|
||||||
|
if d > 0:
|
||||||
|
oe[idx + d, idx] = norm_diag
|
||||||
|
return oe
|
||||||
|
|
||||||
|
|
||||||
|
def compute_compartments(
|
||||||
|
mcool_path: str,
|
||||||
|
chrom: str,
|
||||||
|
res: int,
|
||||||
|
orient_signal=None,
|
||||||
|
flip_orient: bool = False,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Return a DataFrame (chrom, start, end, pc1, compartment).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
orient_signal : array-like, optional
|
||||||
|
Per-bin 1-D signal used to fix the sign of PC1.
|
||||||
|
Pass CTCF signal for positive-PC1 = A convention.
|
||||||
|
flip_orient : bool
|
||||||
|
If True, high orient_signal maps to negative PC1 (use with H3K27me3).
|
||||||
|
"""
|
||||||
|
c = cooler.Cooler(f"{mcool_path}::resolutions/{res}")
|
||||||
|
bins_df = c.bins().fetch(chrom).reset_index(drop=True)
|
||||||
|
matrix = c.matrix(balance=True).fetch(chrom).astype(float)
|
||||||
|
|
||||||
|
bad_bins = np.isnan(matrix).all(axis=0) | (matrix.sum(axis=0) == 0)
|
||||||
|
np.nan_to_num(matrix, nan=0.0, copy=False)
|
||||||
|
|
||||||
|
oe = _observed_over_expected(matrix)
|
||||||
|
|
||||||
|
good = ~bad_bins
|
||||||
|
oe_good = oe[np.ix_(good, good)]
|
||||||
|
|
||||||
|
# Zero rows produce NaN in corrcoef; add tiny noise to avoid singularity
|
||||||
|
row_norms = np.linalg.norm(oe_good, axis=1)
|
||||||
|
oe_good[row_norms == 0] += 1e-9
|
||||||
|
|
||||||
|
corr = np.corrcoef(oe_good)
|
||||||
|
np.nan_to_num(corr, nan=0.0, copy=False)
|
||||||
|
|
||||||
|
pca = PCA(n_components=3, random_state=42)
|
||||||
|
pcs = pca.fit_transform(corr)
|
||||||
|
pc1_good = pcs[:, 0]
|
||||||
|
|
||||||
|
pc1 = np.full(len(bins_df), np.nan)
|
||||||
|
pc1[good] = pc1_good
|
||||||
|
|
||||||
|
if orient_signal is not None:
|
||||||
|
sig = np.asarray(orient_signal, dtype=float)
|
||||||
|
sig_good = sig[good]
|
||||||
|
valid = ~np.isnan(sig_good) & ~np.isnan(pc1_good)
|
||||||
|
if valid.sum() > 10:
|
||||||
|
r = np.corrcoef(pc1_good[valid], sig_good[valid])[0, 1]
|
||||||
|
# By default: positive orient_signal → positive PC1.
|
||||||
|
# flip_orient reverses this (e.g. H3K27me3 → positive PC1 = B).
|
||||||
|
if (r < 0 and not flip_orient) or (r > 0 and flip_orient):
|
||||||
|
pc1 = -pc1
|
||||||
|
|
||||||
|
bins_df["pc1"] = pc1
|
||||||
|
bins_df["compartment"] = np.where(
|
||||||
|
np.isnan(bins_df["pc1"]), "N",
|
||||||
|
np.where(bins_df["pc1"] > 0, "A", "B"),
|
||||||
|
)
|
||||||
|
return bins_df[["chrom", "start", "end", "pc1", "compartment"]]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
p = argparse.ArgumentParser(
|
||||||
|
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
|
)
|
||||||
|
p.add_argument("--mcool", required=True, help="Path to .mcool file")
|
||||||
|
p.add_argument("--chrom", required=True, help="Chromosome (e.g. chr21)")
|
||||||
|
p.add_argument("--res", type=int, default=25000, help="Resolution in bp (default: 25000)")
|
||||||
|
p.add_argument("--bigwig_orient",
|
||||||
|
help="bigWig track for PC1 sign orientation (recommended: CTCF)")
|
||||||
|
p.add_argument("--flip_orient", action="store_true",
|
||||||
|
help="Flip orientation: high signal → negative PC1 (use with H3K27me3)")
|
||||||
|
p.add_argument("--out", required=True, help="Output CSV path")
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
orient_signal = None
|
||||||
|
if args.bigwig_orient:
|
||||||
|
c = cooler.Cooler(f"{args.mcool}::resolutions/{args.res}")
|
||||||
|
bins_df = c.bins().fetch(args.chrom).reset_index(drop=True)
|
||||||
|
coords = list(zip(bins_df["start"].values, bins_df["end"].values))
|
||||||
|
orient_signal = _bin_bigwig(args.bigwig_orient, args.chrom, coords)
|
||||||
|
print(f"Loaded orientation signal: {os.path.basename(args.bigwig_orient)}")
|
||||||
|
|
||||||
|
df = compute_compartments(
|
||||||
|
args.mcool, args.chrom, args.res,
|
||||||
|
orient_signal=orient_signal,
|
||||||
|
flip_orient=args.flip_orient,
|
||||||
|
)
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
|
||||||
|
df.to_csv(args.out, index=False)
|
||||||
|
|
||||||
|
n_a = (df["compartment"] == "A").sum()
|
||||||
|
n_b = (df["compartment"] == "B").sum()
|
||||||
|
n_nan = (df["compartment"] == "N").sum()
|
||||||
|
print(f"Saved → {args.out}")
|
||||||
|
print(f" A: {n_a} B: {n_b} N/masked: {n_nan} bins")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,63 +1,91 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Encode a new graph using a trained VGAE model.
|
Encode a chromatin contact graph using a trained VGAE model.
|
||||||
Automatically infers hidden/latent dimensions from saved weights.
|
|
||||||
|
Dimensions (in_dim, hidden, latent) are inferred automatically from the saved
|
||||||
|
state_dict. The BatchNorm running statistics from training are restored, so the
|
||||||
|
same normalisation is applied to held-out cell lines without a separate scaler.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
python scripts/encode_graph.py \\
|
||||||
|
--model results/GM12878/model.pt \\
|
||||||
|
--graph data/processed/IMR90_chr21.pt \\
|
||||||
|
--out results/IMR90/emb.npy
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse, torch, numpy as np
|
import argparse
|
||||||
from torch_geometric.nn import GCNConv, VGAE
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
# Reuse your Encoder definition directly here for clarity
|
import numpy as np
|
||||||
class Encoder(torch.nn.Module):
|
import torch
|
||||||
def __init__(self, in_dim, hidden, latent, dropout=0.2):
|
from torch_geometric.nn.models import VGAE
|
||||||
super().__init__()
|
|
||||||
self.gc1 = GCNConv(in_dim, hidden)
|
|
||||||
self.gc_mu = GCNConv(hidden, latent)
|
|
||||||
self.gc_log = GCNConv(hidden, latent)
|
|
||||||
self.dropout = dropout
|
|
||||||
|
|
||||||
def forward(self, x, edge_index):
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
import torch.nn.functional as F
|
from model import Encoder
|
||||||
h = self.gc1(x, edge_index)
|
|
||||||
h = F.relu(h)
|
|
||||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
def _infer_dims(state_dict: dict) -> tuple:
|
||||||
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
|
"""Infer (in_dim, hidden, latent) from a VGAE state_dict."""
|
||||||
|
keys = list(state_dict.keys())
|
||||||
|
|
||||||
|
def _first_weight(substr):
|
||||||
|
for k in keys:
|
||||||
|
if (substr in k
|
||||||
|
and "weight" in k
|
||||||
|
and "running" not in k
|
||||||
|
and "num_batches" not in k):
|
||||||
|
return state_dict[k].shape
|
||||||
|
raise KeyError(f"No weight key containing '{substr}' in state_dict. "
|
||||||
|
f"Available keys: {keys}")
|
||||||
|
|
||||||
|
gc1_shape = _first_weight("gc1") # shape [hidden, in_dim]
|
||||||
|
gc_mu_shape = _first_weight("gc_mu") # shape [latent, hidden]
|
||||||
|
hidden = gc1_shape[0]
|
||||||
|
latent = gc_mu_shape[0]
|
||||||
|
|
||||||
|
# in_dim from BatchNorm weight (shape [in_dim])
|
||||||
|
for k in keys:
|
||||||
|
if "norm" in k and k.endswith("weight") and "running" not in k:
|
||||||
|
in_dim = state_dict[k].shape[0]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
in_dim = gc1_shape[1] # fallback: second dim of gc1 weight
|
||||||
|
|
||||||
|
return in_dim, hidden, latent
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
p = argparse.ArgumentParser()
|
p = argparse.ArgumentParser(
|
||||||
p.add_argument("--model", required=True)
|
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
p.add_argument("--graph", required=True)
|
)
|
||||||
p.add_argument("--out", required=True)
|
p.add_argument("--model", required=True,
|
||||||
|
help="Path to model.pt saved by train_vgae.py")
|
||||||
|
p.add_argument("--graph", required=True,
|
||||||
|
help="Path to Data .pt file from build_graph.py")
|
||||||
|
p.add_argument("--out", required=True,
|
||||||
|
help="Output .npy path for node embeddings")
|
||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
|
|
||||||
# ---- Load data and model state ----
|
data = torch.load(args.graph, weights_only=False)
|
||||||
data = torch.load(args.graph)
|
state_dict = torch.load(args.model, map_location="cpu", weights_only=False)
|
||||||
model_state = torch.load(args.model, map_location="cpu")
|
|
||||||
|
|
||||||
# ---- Infer dimensions dynamically ----
|
in_dim, hidden, latent = _infer_dims(state_dict)
|
||||||
in_dim = data.x.size(1)
|
print(f"Inferred: in_dim={in_dim} hidden={hidden} latent={latent}")
|
||||||
# detect hidden and latent dimensions safely
|
|
||||||
keys = list(model_state.keys())
|
|
||||||
gc1_weight = [k for k in keys if "gc1" in k and "weight" in k][0]
|
|
||||||
gc_mu_weight = [k for k in keys if "gc_mu" in k and "weight" in k][0]
|
|
||||||
|
|
||||||
hidden = model_state[gc1_weight].shape[0]
|
enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
|
||||||
latent = model_state[gc_mu_weight].shape[0]
|
|
||||||
|
|
||||||
print(f"Inferred dims: in={in_dim}, hidden={hidden}, latent={latent}")
|
|
||||||
|
|
||||||
enc = Encoder(in_dim=in_dim, hidden=hidden, latent=latent)
|
|
||||||
model = VGAE(enc)
|
model = VGAE(enc)
|
||||||
model.load_state_dict(model_state)
|
model.load_state_dict(state_dict)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# ---- Encode ----
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
z = model.encode(data.x.float(), data.edge_index)
|
z = model.encode(data.x.float(), data.edge_index)
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
|
||||||
np.save(args.out, z.cpu().numpy())
|
np.save(args.out, z.cpu().numpy())
|
||||||
print(f"Saved embeddings → {args.out} shape={z.shape}")
|
print(f"Saved embeddings → {args.out} shape={z.shape}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
32
scripts/model.py
Normal file
32
scripts/model.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Shared VGAE encoder. Imported by train_vgae.py and encode_graph.py."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch_geometric.nn import GCNConv
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
"""Two-layer GCN encoder for VGAE with input BatchNorm.
|
||||||
|
|
||||||
|
Architecture: BatchNorm → GCN(hidden) → ReLU → Dropout → GCN_mu / GCN_logstd
|
||||||
|
|
||||||
|
The BatchNorm layer normalises raw ChIP-seq signals and its running statistics
|
||||||
|
are saved in model.pt, so encode_graph.py applies identical normalisation to
|
||||||
|
held-out cell lines without a separate scaler file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = nn.BatchNorm1d(in_dim)
|
||||||
|
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
|
||||||
|
self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||||
|
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
||||||
|
self.dropout = dropout
|
||||||
|
|
||||||
|
def forward(self, x, edge_index):
|
||||||
|
x = self.norm(x)
|
||||||
|
h = F.relu(self.gc1(x, edge_index))
|
||||||
|
h = F.dropout(h, p=self.dropout, training=self.training)
|
||||||
|
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
|
||||||
@@ -1,179 +1,185 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
|
Train a Variational Graph Autoencoder (VGAE) on a chromatin contact graph.
|
||||||
---
|
|
||||||
Inputs:
|
Inputs
|
||||||
- A PyTorch Geometric Data object saved with torch.save()
|
------
|
||||||
- from build_graph.py
|
PyTorch Geometric Data object saved by build_graph.py.
|
||||||
---
|
|
||||||
Outputs (under results/):
|
Outputs (under --outdir)
|
||||||
- model.pt : trained VGAE state_dict
|
------------------------
|
||||||
- emb.npy : node embeddings (mean; shape [num_nodes, latent_dim])
|
model.pt trained VGAE state_dict (includes BatchNorm running statistics)
|
||||||
- metrics.json : train/val/test AUC/AP summary
|
emb.npy node embeddings — mu vector, shape [num_nodes, latent_dim]
|
||||||
|
metrics.json val/test AUC & AP plus all hyperparameters
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os, json, argparse, numpy as np, torch
|
import argparse
|
||||||
import torch.nn.functional as F
|
import json
|
||||||
from torch_geometric.nn import GCNConv
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
from torch_geometric.nn.models import VGAE
|
from torch_geometric.nn.models import VGAE
|
||||||
from torch_geometric.transforms import RandomLinkSplit
|
from torch_geometric.transforms import RandomLinkSplit
|
||||||
from torch_geometric.utils import to_undirected, remove_self_loops
|
from torch_geometric.utils import (
|
||||||
from torch_geometric.utils import negative_sampling
|
negative_sampling,
|
||||||
from sklearn.metrics import roc_auc_score, average_precision_score
|
remove_self_loops,
|
||||||
|
to_undirected,
|
||||||
|
)
|
||||||
|
from sklearn.metrics import average_precision_score, roc_auc_score
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
class Encoder(torch.nn.Module):
|
from model import Encoder
|
||||||
def __init__(self, in_dim: int, hidden: int, latent: int, dropout: float = 0.2):
|
|
||||||
super().__init__()
|
|
||||||
self.gc1 = GCNConv(in_dim, hidden, add_self_loops=True, normalize=True)
|
|
||||||
self.gc_mu = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
|
||||||
self.gc_log = GCNConv(hidden, latent, add_self_loops=True, normalize=True)
|
|
||||||
self.dropout = dropout
|
|
||||||
|
|
||||||
def forward(self, x, edge_index):
|
|
||||||
h = self.gc1(x, edge_index)
|
|
||||||
h = F.relu(h)
|
|
||||||
h = F.dropout(h, p=self.dropout, training=self.training)
|
|
||||||
return self.gc_mu(h, edge_index), self.gc_log(h, edge_index)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def eval_linkpred(model, data_like, z):
|
def _eval_linkpred(z, pos_edges, neg_edges):
|
||||||
"""Compute AUROC/AP using provided positive/negative edges."""
|
"""Return (AUROC, AP) for link prediction."""
|
||||||
pos = data_like.pos_edge_index
|
def _sigmoid(x):
|
||||||
neg = data_like.neg_edge_index
|
return 1.0 / (1.0 + torch.exp(-x))
|
||||||
# model.test returns (auc, ap) but relies on torchmetrics in some versions;
|
|
||||||
# compute explicitly for stability:
|
|
||||||
def sigmoid(x): return 1 / (1 + torch.exp(-x))
|
|
||||||
|
|
||||||
# Inner product decoder scores
|
def _score(edges):
|
||||||
def scores(edges):
|
|
||||||
src, dst = edges
|
src, dst = edges
|
||||||
s = (z[src] * z[dst]).sum(dim=1)
|
return _sigmoid((z[src] * z[dst]).sum(dim=1)).cpu().numpy()
|
||||||
return sigmoid(s).cpu().numpy()
|
|
||||||
|
|
||||||
y_true = np.concatenate([np.ones(pos.size(1)), np.zeros(neg.size(1))])
|
y_true = np.concatenate([np.ones(pos_edges.size(1)),
|
||||||
y_pred = np.concatenate([scores(pos), scores(neg)])
|
np.zeros(neg_edges.size(1))])
|
||||||
|
y_pred = np.concatenate([_score(pos_edges), _score(neg_edges)])
|
||||||
auc = roc_auc_score(y_true, y_pred)
|
return roc_auc_score(y_true, y_pred), average_precision_score(y_true, y_pred)
|
||||||
ap = average_precision_score(y_true, y_pred)
|
|
||||||
return auc, ap
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
ap = argparse.ArgumentParser()
|
ap = argparse.ArgumentParser(
|
||||||
ap.add_argument("--graph", required=True, help="Path to Data .pt file")
|
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
ap.add_argument("--epochs", type=int, default=100)
|
)
|
||||||
ap.add_argument("--lr", type=float, default=1e-3)
|
ap.add_argument("--graph", required=True,
|
||||||
ap.add_argument("--hidden", type=int, default=128)
|
help="Path to Data .pt file from build_graph.py")
|
||||||
ap.add_argument("--latent", type=int, default=64)
|
ap.add_argument("--epochs", type=int, default=300)
|
||||||
ap.add_argument("--dropout", type=float, default=0.2)
|
ap.add_argument("--patience", type=int, default=20,
|
||||||
ap.add_argument("--seed", type=int, default=42)
|
help="Early-stopping patience (val-AUC epochs without improvement)")
|
||||||
ap.add_argument("--outdir", default="results")
|
ap.add_argument("--lr", type=float, default=1e-3)
|
||||||
|
ap.add_argument("--hidden", type=int, default=64)
|
||||||
|
ap.add_argument("--latent", type=int, default=32)
|
||||||
|
ap.add_argument("--dropout", type=float, default=0.2)
|
||||||
|
ap.add_argument("--seed", type=int, default=42)
|
||||||
|
ap.add_argument("--outdir", default="results")
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
|
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
os.makedirs(args.outdir, exist_ok=True)
|
os.makedirs(args.outdir, exist_ok=True)
|
||||||
|
|
||||||
# Load graph
|
# ---- Load and clean graph ----
|
||||||
data = torch.load(args.graph)
|
data = torch.load(args.graph, weights_only=False)
|
||||||
# Coalesce/clean edges
|
|
||||||
ei, _ = remove_self_loops(data.edge_index)
|
ei, _ = remove_self_loops(data.edge_index)
|
||||||
data.edge_index = to_undirected(ei, num_nodes=data.num_nodes)
|
data.edge_index = to_undirected(ei, num_nodes=data.num_nodes)
|
||||||
x = data.x.float()
|
x = data.x.float()
|
||||||
|
print(f"Graph: {data.num_nodes} nodes "
|
||||||
|
f"{data.edge_index.shape[1]} edges "
|
||||||
|
f"{x.shape[1]} node features")
|
||||||
|
|
||||||
# Split edges for link prediction
|
# ---- Edge splits for link-prediction evaluation ----
|
||||||
splitter = RandomLinkSplit(
|
splitter = RandomLinkSplit(
|
||||||
num_val=0.1,
|
num_val=0.1, num_test=0.1,
|
||||||
num_test=0.1,
|
|
||||||
is_undirected=True,
|
is_undirected=True,
|
||||||
add_negative_train_samples=False,
|
add_negative_train_samples=False,
|
||||||
split_labels=False,
|
split_labels=False,
|
||||||
)
|
)
|
||||||
train_data, val_data, test_data = splitter(data)
|
train_data, val_data, test_data = splitter(data)
|
||||||
|
|
||||||
# Positive edges are just the edges in each split
|
|
||||||
train_data.pos_edge_index = train_data.edge_index
|
train_data.pos_edge_index = train_data.edge_index
|
||||||
val_data.pos_edge_index = val_data.edge_index
|
|
||||||
test_data.pos_edge_index = test_data.edge_index
|
|
||||||
|
|
||||||
# Generate negative edges for validation and test manually
|
for split in (val_data, test_data):
|
||||||
for subset in [val_data, test_data]:
|
split.pos_edge_index = split.edge_index
|
||||||
subset.neg_edge_index = negative_sampling(
|
split.neg_edge_index = negative_sampling(
|
||||||
edge_index=subset.edge_index,
|
edge_index=split.edge_index,
|
||||||
num_nodes=data.num_nodes,
|
num_nodes=data.num_nodes,
|
||||||
num_neg_samples=subset.edge_index.size(1),
|
num_neg_samples=split.edge_index.size(1),
|
||||||
method='sparse'
|
method="sparse",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ---- Model ----
|
||||||
# Model
|
enc = Encoder(in_dim=x.size(1), hidden=args.hidden,
|
||||||
enc = Encoder(in_dim=x.size(1), hidden=args.hidden, latent=args.latent, dropout=args.dropout)
|
latent=args.latent, dropout=args.dropout)
|
||||||
model = VGAE(enc)
|
model = VGAE(enc)
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||||
|
|
||||||
# Training loop
|
# ---- Training loop with early stopping ----
|
||||||
best_val_auc = -1.0
|
best_val_auc = -1.0
|
||||||
best_state = None
|
best_state = None
|
||||||
|
no_improve = 0
|
||||||
|
epochs_ran = 0
|
||||||
|
|
||||||
for epoch in range(1, args.epochs + 1):
|
for epoch in range(1, args.epochs + 1):
|
||||||
model.train()
|
model.train()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
# Encode using remaining training edges
|
z = model.encode(x, train_data.edge_index)
|
||||||
z = model.encode(x, train_data.edge_index)
|
loss = (model.recon_loss(z, train_data.pos_edge_index)
|
||||||
# Reconstruction loss on positive training edges (negatives sampled inside)
|
+ (1.0 / data.num_nodes) * model.kl_loss())
|
||||||
loss_recon = model.recon_loss(z, train_data.pos_edge_index)
|
|
||||||
# KL divergence regularizer
|
|
||||||
loss_kl = (1.0 / data.num_nodes) * model.kl_loss()
|
|
||||||
loss = loss_recon + loss_kl
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# Validation
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
z_full = model.encode(x, data.edge_index) # use full graph for eval embeddings
|
z_full = model.encode(x, data.edge_index)
|
||||||
val_auc, val_ap = eval_linkpred(model, val_data, z_full)
|
val_auc, val_ap = _eval_linkpred(
|
||||||
|
z_full, val_data.pos_edge_index, val_data.neg_edge_index
|
||||||
|
)
|
||||||
|
|
||||||
if val_auc > best_val_auc:
|
if val_auc > best_val_auc:
|
||||||
best_val_auc = val_auc
|
best_val_auc = val_auc
|
||||||
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
best_state = {k: v.cpu().clone()
|
||||||
|
for k, v in model.state_dict().items()}
|
||||||
|
no_improve = 0
|
||||||
|
else:
|
||||||
|
no_improve += 1
|
||||||
|
|
||||||
|
epochs_ran = epoch
|
||||||
if epoch % 10 == 0 or epoch == 1:
|
if epoch % 10 == 0 or epoch == 1:
|
||||||
print(f"[{epoch:03d}/{args.epochs}] loss={loss.item():.4f} | val AUC={val_auc:.4f} AP={val_ap:.4f}")
|
print(f"[{epoch:03d}/{args.epochs}] "
|
||||||
|
f"loss={loss.item():.4f} "
|
||||||
|
f"val AUC={val_auc:.4f} AP={val_ap:.4f}")
|
||||||
|
|
||||||
# Save best model
|
if no_improve >= args.patience:
|
||||||
|
print(f"Early stopping at epoch {epoch} "
|
||||||
|
f"(no val-AUC improvement for {args.patience} epochs)")
|
||||||
|
break
|
||||||
|
|
||||||
|
# ---- Restore best checkpoint and compute test metrics ----
|
||||||
model.load_state_dict(best_state)
|
model.load_state_dict(best_state)
|
||||||
model_path = os.path.join(args.outdir, "model.pt")
|
|
||||||
torch.save(model.state_dict(), model_path)
|
|
||||||
|
|
||||||
# Final test metrics
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
z_final = model.encode(x, data.edge_index)
|
z_final = model.encode(x, data.edge_index)
|
||||||
test_auc, test_ap = eval_linkpred(model, test_data, z_final)
|
test_auc, test_ap = _eval_linkpred(
|
||||||
|
z_final, test_data.pos_edge_index, test_data.neg_edge_index
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- Save outputs ----
|
||||||
|
model_path = os.path.join(args.outdir, "model.pt")
|
||||||
|
torch.save(best_state, model_path)
|
||||||
|
|
||||||
# Save embeddings & metrics
|
|
||||||
emb_path = os.path.join(args.outdir, "emb.npy")
|
emb_path = os.path.join(args.outdir, "emb.npy")
|
||||||
np.save(emb_path, z_final.cpu().numpy())
|
np.save(emb_path, z_final.cpu().numpy())
|
||||||
|
|
||||||
metrics = {
|
metrics = {
|
||||||
"val_auc": float(best_val_auc),
|
"val_auc": float(best_val_auc),
|
||||||
"test_auc": float(test_auc),
|
"test_auc": float(test_auc),
|
||||||
"test_ap": float(test_ap),
|
"test_ap": float(test_ap),
|
||||||
"epochs": args.epochs,
|
"epochs_ran": epochs_ran,
|
||||||
"hidden": args.hidden,
|
"epochs_max": args.epochs,
|
||||||
"latent": args.latent,
|
"patience": args.patience,
|
||||||
"dropout": args.dropout,
|
"hidden": args.hidden,
|
||||||
"lr": args.lr,
|
"latent": args.latent,
|
||||||
"seed": args.seed
|
"dropout": args.dropout,
|
||||||
|
"lr": args.lr,
|
||||||
|
"seed": args.seed,
|
||||||
}
|
}
|
||||||
with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
|
with open(os.path.join(args.outdir, "metrics.json"), "w") as f:
|
||||||
json.dump(metrics, f, indent=2)
|
json.dump(metrics, f, indent=2)
|
||||||
|
|
||||||
print(f"Saved model -> {model_path}")
|
print(f"\nSaved model → {model_path}")
|
||||||
print(f"Saved embeddings -> {emb_path} (shape={z_final.shape})")
|
print(f"Saved embeddings → {emb_path} shape={z_final.shape}")
|
||||||
print(f"Metrics: AUC(test)={test_auc:.4f}, AP(test)={test_ap:.4f}")
|
print(f"Test AUC={test_auc:.4f} AP={test_ap:.4f}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
189
scripts/visualize_embeddings.py
Normal file
189
scripts/visualize_embeddings.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Visualise VGAE node embeddings using UMAP.
|
||||||
|
|
||||||
|
Produces (under --prefix):
|
||||||
|
{prefix}_{label}_position.png UMAP coloured by genomic position (bin index)
|
||||||
|
{prefix}_{label}_compartment.png UMAP coloured by A/B compartment (needs --compartments)
|
||||||
|
{prefix}_joint.png Joint UMAP of all supplied cell lines
|
||||||
|
{prefix}_stats.csv Per-embedding summary statistics
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
python scripts/visualize_embeddings.py \\
|
||||||
|
--emb results/GM12878/emb.npy results/IMR90/emb.npy \\
|
||||||
|
--labels GM12878 IMR90 \\
|
||||||
|
--compartments results/GM12878/compartments_chr21.csv \\
|
||||||
|
results/IMR90/compartments_chr21.csv \\
|
||||||
|
--prefix results/figures/umap
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn.metrics import silhouette_score
|
||||||
|
import umap
|
||||||
|
|
||||||
|
|
||||||
|
COMPARTMENT_COLORS = {"A": "#E41A1C", "B": "#377EB8", "N": "#AAAAAA"}
|
||||||
|
CELL_LINE_PALETTE = ["#E41A1C", "#4DAF4A", "#984EA3", "#FF7F00", "#377EB8"]
|
||||||
|
|
||||||
|
plt.rcParams.update({
|
||||||
|
"font.family": "sans-serif",
|
||||||
|
"axes.spines.top": False,
|
||||||
|
"axes.spines.right": False,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _run_umap(emb: np.ndarray, seed: int = 42) -> np.ndarray:
|
||||||
|
reducer = umap.UMAP(n_components=2, random_state=seed,
|
||||||
|
min_dist=0.3, n_neighbors=15)
|
||||||
|
return reducer.fit_transform(emb)
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_position(coords: np.ndarray, label: str, out_path: str):
|
||||||
|
fig, ax = plt.subplots(figsize=(6.5, 5.5))
|
||||||
|
sc = ax.scatter(coords[:, 0], coords[:, 1],
|
||||||
|
c=np.arange(len(coords)), cmap="plasma",
|
||||||
|
s=4, alpha=0.75, linewidths=0, rasterized=True)
|
||||||
|
cbar = plt.colorbar(sc, ax=ax, shrink=0.8, pad=0.02)
|
||||||
|
cbar.set_label("Bin index (5′ → 3′)", fontsize=9)
|
||||||
|
ax.set_title(f"{label} — UMAP coloured by genomic position", fontsize=10)
|
||||||
|
ax.set_xlabel("UMAP 1", fontsize=9)
|
||||||
|
ax.set_ylabel("UMAP 2", fontsize=9)
|
||||||
|
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(out_path, dpi=300)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_compartment(coords: np.ndarray, compartments: np.ndarray,
|
||||||
|
label: str, out_path: str):
|
||||||
|
fig, ax = plt.subplots(figsize=(6.5, 5.5))
|
||||||
|
for comp in ("A", "B", "N"):
|
||||||
|
mask = compartments == comp
|
||||||
|
if mask.sum() == 0:
|
||||||
|
continue
|
||||||
|
ax.scatter(coords[mask, 0], coords[mask, 1],
|
||||||
|
c=COMPARTMENT_COLORS[comp], s=4, alpha=0.75,
|
||||||
|
label=f"{comp} ({mask.sum()} bins)", linewidths=0,
|
||||||
|
rasterized=True)
|
||||||
|
ax.legend(markerscale=3, title="Compartment", fontsize=9,
|
||||||
|
title_fontsize=9, frameon=False)
|
||||||
|
ax.set_title(f"{label} — UMAP coloured by A/B compartment", fontsize=10)
|
||||||
|
ax.set_xlabel("UMAP 1", fontsize=9)
|
||||||
|
ax.set_ylabel("UMAP 2", fontsize=9)
|
||||||
|
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(out_path, dpi=300)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _plot_joint(all_coords: np.ndarray, all_labels: list, out_path: str):
|
||||||
|
fig, ax = plt.subplots(figsize=(7, 6))
|
||||||
|
unique = list(dict.fromkeys(all_labels))
|
||||||
|
arr = np.array(all_labels)
|
||||||
|
for i, label in enumerate(unique):
|
||||||
|
mask = arr == label
|
||||||
|
ax.scatter(all_coords[mask, 0], all_coords[mask, 1],
|
||||||
|
c=CELL_LINE_PALETTE[i % len(CELL_LINE_PALETTE)],
|
||||||
|
s=3, alpha=0.6, label=label, linewidths=0, rasterized=True)
|
||||||
|
ax.legend(markerscale=4, title="Cell line", fontsize=9,
|
||||||
|
title_fontsize=9, frameon=False)
|
||||||
|
ax.set_title("Joint UMAP — chromatin topology embeddings", fontsize=11)
|
||||||
|
ax.set_xlabel("UMAP 1", fontsize=9)
|
||||||
|
ax.set_ylabel("UMAP 2", fontsize=9)
|
||||||
|
ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(out_path, dpi=300)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _silhouette(emb: np.ndarray, compartments: np.ndarray) -> float:
|
||||||
|
mask = compartments != "N"
|
||||||
|
if mask.sum() < 20 or len(set(compartments[mask])) < 2:
|
||||||
|
return float("nan")
|
||||||
|
try:
|
||||||
|
return float(silhouette_score(emb[mask], compartments[mask], metric="cosine"))
|
||||||
|
except Exception:
|
||||||
|
return float("nan")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
p = argparse.ArgumentParser(
|
||||||
|
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||||
|
)
|
||||||
|
p.add_argument("--emb", nargs="+", required=True,
|
||||||
|
help="One or more .npy embedding files")
|
||||||
|
p.add_argument("--labels", nargs="+", required=True,
|
||||||
|
help="Label for each embedding (same order)")
|
||||||
|
p.add_argument("--compartments", nargs="+",
|
||||||
|
help="Compartment CSV files, one per embedding (optional)")
|
||||||
|
p.add_argument("--prefix", default="results/figures/umap",
|
||||||
|
help="Output file prefix")
|
||||||
|
p.add_argument("--seed", type=int, default=42)
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
if len(args.emb) != len(args.labels):
|
||||||
|
raise ValueError("--emb and --labels must have the same length")
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(os.path.abspath(args.prefix + "_x")), exist_ok=True)
|
||||||
|
|
||||||
|
embs = [np.load(f) for f in args.emb]
|
||||||
|
comp_dfs = []
|
||||||
|
if args.compartments:
|
||||||
|
for f in args.compartments:
|
||||||
|
comp_dfs.append(pd.read_csv(f) if (f and os.path.exists(f)) else None)
|
||||||
|
|
||||||
|
stats_rows = []
|
||||||
|
|
||||||
|
for i, (emb, label) in enumerate(zip(embs, args.labels)):
|
||||||
|
print(f"\n[{label}] {emb.shape[0]} nodes × {emb.shape[1]} dims")
|
||||||
|
coords = _run_umap(emb, seed=args.seed)
|
||||||
|
|
||||||
|
tag = label.replace(" ", "_")
|
||||||
|
_plot_position(coords, label, f"{args.prefix}_{tag}_position.png")
|
||||||
|
print(f" → {args.prefix}_{tag}_position.png")
|
||||||
|
|
||||||
|
comp_arr = None
|
||||||
|
sil = float("nan")
|
||||||
|
if comp_dfs and i < len(comp_dfs) and comp_dfs[i] is not None:
|
||||||
|
comp_arr = comp_dfs[i]["compartment"].values[: len(emb)]
|
||||||
|
_plot_compartment(coords, comp_arr, label,
|
||||||
|
f"{args.prefix}_{tag}_compartment.png")
|
||||||
|
print(f" → {args.prefix}_{tag}_compartment.png")
|
||||||
|
sil = _silhouette(emb, comp_arr)
|
||||||
|
print(f" Silhouette (A/B, cosine): {sil:.4f}")
|
||||||
|
|
||||||
|
stats_rows.append({
|
||||||
|
"label": label,
|
||||||
|
"n_bins": emb.shape[0],
|
||||||
|
"latent_dim": emb.shape[1],
|
||||||
|
"mean_embedding_norm": float(np.linalg.norm(emb, axis=1).mean()),
|
||||||
|
"std_embedding_values": float(emb.std()),
|
||||||
|
"silhouette_AB_cosine": sil,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Joint UMAP when multiple embeddings are supplied
|
||||||
|
if len(embs) > 1:
|
||||||
|
print("\nComputing joint UMAP…")
|
||||||
|
all_emb = np.vstack(embs)
|
||||||
|
all_labels = sum([[lab] * len(e) for lab, e in zip(args.labels, embs)], [])
|
||||||
|
all_coords = _run_umap(all_emb, seed=args.seed)
|
||||||
|
_plot_joint(all_coords, all_labels, f"{args.prefix}_joint.png")
|
||||||
|
print(f" → {args.prefix}_joint.png")
|
||||||
|
|
||||||
|
stats_df = pd.DataFrame(stats_rows)
|
||||||
|
stats_path = f"{args.prefix}_stats.csv"
|
||||||
|
stats_df.to_csv(stats_path, index=False)
|
||||||
|
print(f"\nStats → {stats_path}")
|
||||||
|
print(stats_df.to_string(index=False))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user